github.com/amnezia-vpn/amneziawg-go@v0.2.8/conn/bind_windows.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package conn 7 8 import ( 9 "encoding/binary" 10 "io" 11 "net" 12 "net/netip" 13 "strconv" 14 "sync" 15 "sync/atomic" 16 "unsafe" 17 18 "golang.org/x/sys/windows" 19 20 "github.com/amnezia-vpn/amneziawg-go/conn/winrio" 21 ) 22 23 const ( 24 packetsPerRing = 1024 25 bytesPerPacket = 2048 - 32 26 receiveSpins = 15 27 ) 28 29 type ringPacket struct { 30 addr WinRingEndpoint 31 data [bytesPerPacket]byte 32 } 33 34 type ringBuffer struct { 35 packets uintptr 36 head, tail uint32 37 id winrio.BufferId 38 iocp windows.Handle 39 isFull bool 40 cq winrio.Cq 41 mu sync.Mutex 42 overlapped windows.Overlapped 43 } 44 45 func (rb *ringBuffer) Push() *ringPacket { 46 for rb.isFull { 47 panic("ring is full") 48 } 49 ret := (*ringPacket)(unsafe.Pointer(rb.packets + (uintptr(rb.tail%packetsPerRing) * unsafe.Sizeof(ringPacket{})))) 50 rb.tail += 1 51 if rb.tail%packetsPerRing == rb.head%packetsPerRing { 52 rb.isFull = true 53 } 54 return ret 55 } 56 57 func (rb *ringBuffer) Return(count uint32) { 58 if rb.head%packetsPerRing == rb.tail%packetsPerRing && !rb.isFull { 59 return 60 } 61 rb.head += count 62 rb.isFull = false 63 } 64 65 type afWinRingBind struct { 66 sock windows.Handle 67 rx, tx ringBuffer 68 rq winrio.Rq 69 mu sync.Mutex 70 blackhole bool 71 } 72 73 // WinRingBind uses Windows registered I/O for fast ring buffered networking. 74 type WinRingBind struct { 75 v4, v6 afWinRingBind 76 mu sync.RWMutex 77 isOpen atomic.Uint32 // 0, 1, or 2 78 } 79 80 func NewDefaultBind() Bind { return NewWinRingBind() } 81 82 func NewWinRingBind() Bind { 83 if !winrio.Initialize() { 84 return NewStdNetBind() 85 } 86 return new(WinRingBind) 87 } 88 89 type WinRingEndpoint struct { 90 family uint16 91 data [30]byte 92 } 93 94 var ( 95 _ Bind = (*WinRingBind)(nil) 96 _ Endpoint = (*WinRingEndpoint)(nil) 97 ) 98 99 func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) { 100 host, port, err := net.SplitHostPort(s) 101 if err != nil { 102 return nil, err 103 } 104 host16, err := windows.UTF16PtrFromString(host) 105 if err != nil { 106 return nil, err 107 } 108 port16, err := windows.UTF16PtrFromString(port) 109 if err != nil { 110 return nil, err 111 } 112 hints := windows.AddrinfoW{ 113 Flags: windows.AI_NUMERICHOST, 114 Family: windows.AF_UNSPEC, 115 Socktype: windows.SOCK_DGRAM, 116 Protocol: windows.IPPROTO_UDP, 117 } 118 var addrinfo *windows.AddrinfoW 119 err = windows.GetAddrInfoW(host16, port16, &hints, &addrinfo) 120 if err != nil { 121 return nil, err 122 } 123 defer windows.FreeAddrInfoW(addrinfo) 124 if (addrinfo.Family != windows.AF_INET && addrinfo.Family != windows.AF_INET6) || addrinfo.Addrlen > unsafe.Sizeof(WinRingEndpoint{}) { 125 return nil, windows.ERROR_INVALID_ADDRESS 126 } 127 var dst [unsafe.Sizeof(WinRingEndpoint{})]byte 128 copy(dst[:], unsafe.Slice((*byte)(unsafe.Pointer(addrinfo.Addr)), addrinfo.Addrlen)) 129 return (*WinRingEndpoint)(unsafe.Pointer(&dst[0])), nil 130 } 131 132 func (*WinRingEndpoint) ClearSrc() {} 133 134 func (e *WinRingEndpoint) DstIP() netip.Addr { 135 switch e.family { 136 case windows.AF_INET: 137 return netip.AddrFrom4(*(*[4]byte)(e.data[2:6])) 138 case windows.AF_INET6: 139 return netip.AddrFrom16(*(*[16]byte)(e.data[6:22])) 140 } 141 return netip.Addr{} 142 } 143 144 func (e *WinRingEndpoint) SrcIP() netip.Addr { 145 return netip.Addr{} // not supported 146 } 147 148 func (e *WinRingEndpoint) DstToBytes() []byte { 149 switch e.family { 150 case windows.AF_INET: 151 b := make([]byte, 0, 6) 152 b = append(b, e.data[2:6]...) 153 b = append(b, e.data[1], e.data[0]) 154 return b 155 case windows.AF_INET6: 156 b := make([]byte, 0, 18) 157 b = append(b, e.data[6:22]...) 158 b = append(b, e.data[1], e.data[0]) 159 return b 160 } 161 return nil 162 } 163 164 func (e *WinRingEndpoint) DstToString() string { 165 switch e.family { 166 case windows.AF_INET: 167 return netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String() 168 case windows.AF_INET6: 169 var zone string 170 if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 { 171 zone = strconv.FormatUint(uint64(scope), 10) 172 } 173 return netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)(e.data[6:22])).WithZone(zone), binary.BigEndian.Uint16(e.data[0:2])).String() 174 } 175 return "" 176 } 177 178 func (e *WinRingEndpoint) SrcToString() string { 179 return "" 180 } 181 182 func (ring *ringBuffer) CloseAndZero() { 183 if ring.cq != 0 { 184 winrio.CloseCompletionQueue(ring.cq) 185 ring.cq = 0 186 } 187 if ring.iocp != 0 { 188 windows.CloseHandle(ring.iocp) 189 ring.iocp = 0 190 } 191 if ring.id != 0 { 192 winrio.DeregisterBuffer(ring.id) 193 ring.id = 0 194 } 195 if ring.packets != 0 { 196 windows.VirtualFree(ring.packets, 0, windows.MEM_RELEASE) 197 ring.packets = 0 198 } 199 ring.head = 0 200 ring.tail = 0 201 ring.isFull = false 202 } 203 204 func (bind *afWinRingBind) CloseAndZero() { 205 bind.rx.CloseAndZero() 206 bind.tx.CloseAndZero() 207 if bind.sock != 0 { 208 windows.CloseHandle(bind.sock) 209 bind.sock = 0 210 } 211 bind.blackhole = false 212 } 213 214 func (bind *WinRingBind) closeAndZero() { 215 bind.isOpen.Store(0) 216 bind.v4.CloseAndZero() 217 bind.v6.CloseAndZero() 218 } 219 220 func (ring *ringBuffer) Open() error { 221 var err error 222 packetsLen := unsafe.Sizeof(ringPacket{}) * packetsPerRing 223 ring.packets, err = windows.VirtualAlloc(0, packetsLen, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE) 224 if err != nil { 225 return err 226 } 227 ring.id, err = winrio.RegisterPointer(unsafe.Pointer(ring.packets), uint32(packetsLen)) 228 if err != nil { 229 return err 230 } 231 ring.iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0) 232 if err != nil { 233 return err 234 } 235 ring.cq, err = winrio.CreateIOCPCompletionQueue(packetsPerRing, ring.iocp, 0, &ring.overlapped) 236 if err != nil { 237 return err 238 } 239 return nil 240 } 241 242 func (bind *afWinRingBind) Open(family int32, sa windows.Sockaddr) (windows.Sockaddr, error) { 243 var err error 244 bind.sock, err = winrio.Socket(family, windows.SOCK_DGRAM, windows.IPPROTO_UDP) 245 if err != nil { 246 return nil, err 247 } 248 err = bind.rx.Open() 249 if err != nil { 250 return nil, err 251 } 252 err = bind.tx.Open() 253 if err != nil { 254 return nil, err 255 } 256 bind.rq, err = winrio.CreateRequestQueue(bind.sock, packetsPerRing, 1, packetsPerRing, 1, bind.rx.cq, bind.tx.cq, 0) 257 if err != nil { 258 return nil, err 259 } 260 err = windows.Bind(bind.sock, sa) 261 if err != nil { 262 return nil, err 263 } 264 sa, err = windows.Getsockname(bind.sock) 265 if err != nil { 266 return nil, err 267 } 268 return sa, nil 269 } 270 271 func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort uint16, err error) { 272 bind.mu.Lock() 273 defer bind.mu.Unlock() 274 defer func() { 275 if err != nil { 276 bind.closeAndZero() 277 } 278 }() 279 if bind.isOpen.Load() != 0 { 280 return nil, 0, ErrBindAlreadyOpen 281 } 282 var sa windows.Sockaddr 283 sa, err = bind.v4.Open(windows.AF_INET, &windows.SockaddrInet4{Port: int(port)}) 284 if err != nil { 285 return nil, 0, err 286 } 287 sa, err = bind.v6.Open(windows.AF_INET6, &windows.SockaddrInet6{Port: sa.(*windows.SockaddrInet4).Port}) 288 if err != nil { 289 return nil, 0, err 290 } 291 selectedPort = uint16(sa.(*windows.SockaddrInet6).Port) 292 for i := 0; i < packetsPerRing; i++ { 293 err = bind.v4.InsertReceiveRequest() 294 if err != nil { 295 return nil, 0, err 296 } 297 err = bind.v6.InsertReceiveRequest() 298 if err != nil { 299 return nil, 0, err 300 } 301 } 302 bind.isOpen.Store(1) 303 return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err 304 } 305 306 func (bind *WinRingBind) Close() error { 307 bind.mu.RLock() 308 if bind.isOpen.Load() != 1 { 309 bind.mu.RUnlock() 310 return nil 311 } 312 bind.isOpen.Store(2) 313 windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil) 314 windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil) 315 windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil) 316 windows.PostQueuedCompletionStatus(bind.v6.tx.iocp, 0, 0, nil) 317 bind.mu.RUnlock() 318 bind.mu.Lock() 319 defer bind.mu.Unlock() 320 bind.closeAndZero() 321 return nil 322 } 323 324 // TODO: When all Binds handle IdealBatchSize, remove this dynamic function and 325 // rename the IdealBatchSize constant to BatchSize. 326 func (bind *WinRingBind) BatchSize() int { 327 // TODO: implement batching in and out of the ring 328 return 1 329 } 330 331 func (bind *WinRingBind) GetOffloadInfo() string { 332 return "" 333 } 334 335 func (bind *WinRingBind) SetMark(mark uint32) error { 336 return nil 337 } 338 339 func (bind *afWinRingBind) InsertReceiveRequest() error { 340 packet := bind.rx.Push() 341 dataBuffer := &winrio.Buffer{ 342 Id: bind.rx.id, 343 Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.rx.packets), 344 Length: uint32(len(packet.data)), 345 } 346 addressBuffer := &winrio.Buffer{ 347 Id: bind.rx.id, 348 Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.rx.packets), 349 Length: uint32(unsafe.Sizeof(packet.addr)), 350 } 351 bind.mu.Lock() 352 defer bind.mu.Unlock() 353 return winrio.ReceiveEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, uintptr(unsafe.Pointer(packet))) 354 } 355 356 //go:linkname procyield runtime.procyield 357 func procyield(cycles uint32) 358 359 func (bind *afWinRingBind) Receive(buf []byte, isOpen *atomic.Uint32) (int, Endpoint, error) { 360 if isOpen.Load() != 1 { 361 return 0, nil, net.ErrClosed 362 } 363 bind.rx.mu.Lock() 364 defer bind.rx.mu.Unlock() 365 366 var err error 367 var count uint32 368 var results [1]winrio.Result 369 retry: 370 count = 0 371 for tries := 0; count == 0 && tries < receiveSpins; tries++ { 372 if tries > 0 { 373 if isOpen.Load() != 1 { 374 return 0, nil, net.ErrClosed 375 } 376 procyield(1) 377 } 378 count = winrio.DequeueCompletion(bind.rx.cq, results[:]) 379 } 380 if count == 0 { 381 err = winrio.Notify(bind.rx.cq) 382 if err != nil { 383 return 0, nil, err 384 } 385 var bytes uint32 386 var key uintptr 387 var overlapped *windows.Overlapped 388 err = windows.GetQueuedCompletionStatus(bind.rx.iocp, &bytes, &key, &overlapped, windows.INFINITE) 389 if err != nil { 390 return 0, nil, err 391 } 392 if isOpen.Load() != 1 { 393 return 0, nil, net.ErrClosed 394 } 395 count = winrio.DequeueCompletion(bind.rx.cq, results[:]) 396 if count == 0 { 397 return 0, nil, io.ErrNoProgress 398 } 399 } 400 bind.rx.Return(1) 401 err = bind.InsertReceiveRequest() 402 if err != nil { 403 return 0, nil, err 404 } 405 // We limit the MTU well below the 65k max for practicality, but this means a remote host can still send us 406 // huge packets. Just try again when this happens. The infinite loop this could cause is still limited to 407 // attacker bandwidth, just like the rest of the receive path. 408 if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE { 409 if isOpen.Load() != 1 { 410 return 0, nil, net.ErrClosed 411 } 412 goto retry 413 } 414 if results[0].Status != 0 { 415 return 0, nil, windows.Errno(results[0].Status) 416 } 417 packet := (*ringPacket)(unsafe.Pointer(uintptr(results[0].RequestContext))) 418 ep := packet.addr 419 n := copy(buf, packet.data[:results[0].BytesTransferred]) 420 return n, &ep, nil 421 } 422 423 func (bind *WinRingBind) receiveIPv4(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) { 424 bind.mu.RLock() 425 defer bind.mu.RUnlock() 426 n, ep, err := bind.v4.Receive(bufs[0], &bind.isOpen) 427 sizes[0] = n 428 eps[0] = ep 429 return 1, err 430 } 431 432 func (bind *WinRingBind) receiveIPv6(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) { 433 bind.mu.RLock() 434 defer bind.mu.RUnlock() 435 n, ep, err := bind.v6.Receive(bufs[0], &bind.isOpen) 436 sizes[0] = n 437 eps[0] = ep 438 return 1, err 439 } 440 441 func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomic.Uint32) error { 442 if isOpen.Load() != 1 { 443 return net.ErrClosed 444 } 445 if len(buf) > bytesPerPacket { 446 return io.ErrShortBuffer 447 } 448 bind.tx.mu.Lock() 449 defer bind.tx.mu.Unlock() 450 var results [packetsPerRing]winrio.Result 451 count := winrio.DequeueCompletion(bind.tx.cq, results[:]) 452 if count == 0 && bind.tx.isFull { 453 err := winrio.Notify(bind.tx.cq) 454 if err != nil { 455 return err 456 } 457 var bytes uint32 458 var key uintptr 459 var overlapped *windows.Overlapped 460 err = windows.GetQueuedCompletionStatus(bind.tx.iocp, &bytes, &key, &overlapped, windows.INFINITE) 461 if err != nil { 462 return err 463 } 464 if isOpen.Load() != 1 { 465 return net.ErrClosed 466 } 467 count = winrio.DequeueCompletion(bind.tx.cq, results[:]) 468 if count == 0 { 469 return io.ErrNoProgress 470 } 471 } 472 if count > 0 { 473 bind.tx.Return(count) 474 } 475 packet := bind.tx.Push() 476 packet.addr = *nend 477 copy(packet.data[:], buf) 478 dataBuffer := &winrio.Buffer{ 479 Id: bind.tx.id, 480 Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.tx.packets), 481 Length: uint32(len(buf)), 482 } 483 addressBuffer := &winrio.Buffer{ 484 Id: bind.tx.id, 485 Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.tx.packets), 486 Length: uint32(unsafe.Sizeof(packet.addr)), 487 } 488 bind.mu.Lock() 489 defer bind.mu.Unlock() 490 return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0) 491 } 492 493 func (bind *WinRingBind) Send(bufs [][]byte, endpoint Endpoint) error { 494 nend, ok := endpoint.(*WinRingEndpoint) 495 if !ok { 496 return ErrWrongEndpointType 497 } 498 bind.mu.RLock() 499 defer bind.mu.RUnlock() 500 for _, buf := range bufs { 501 switch nend.family { 502 case windows.AF_INET: 503 if bind.v4.blackhole { 504 continue 505 } 506 if err := bind.v4.Send(buf, nend, &bind.isOpen); err != nil { 507 return err 508 } 509 case windows.AF_INET6: 510 if bind.v6.blackhole { 511 continue 512 } 513 if err := bind.v6.Send(buf, nend, &bind.isOpen); err != nil { 514 return err 515 } 516 } 517 } 518 return nil 519 } 520 521 func (s *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { 522 s.mu.Lock() 523 defer s.mu.Unlock() 524 sysconn, err := s.ipv4.SyscallConn() 525 if err != nil { 526 return err 527 } 528 err2 := sysconn.Control(func(fd uintptr) { 529 err = bindSocketToInterface4(windows.Handle(fd), interfaceIndex) 530 }) 531 if err2 != nil { 532 return err2 533 } 534 if err != nil { 535 return err 536 } 537 s.blackhole4 = blackhole 538 return nil 539 } 540 541 func (s *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { 542 s.mu.Lock() 543 defer s.mu.Unlock() 544 sysconn, err := s.ipv6.SyscallConn() 545 if err != nil { 546 return err 547 } 548 err2 := sysconn.Control(func(fd uintptr) { 549 err = bindSocketToInterface6(windows.Handle(fd), interfaceIndex) 550 }) 551 if err2 != nil { 552 return err2 553 } 554 if err != nil { 555 return err 556 } 557 s.blackhole6 = blackhole 558 return nil 559 } 560 561 func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { 562 bind.mu.RLock() 563 defer bind.mu.RUnlock() 564 if bind.isOpen.Load() != 1 { 565 return net.ErrClosed 566 } 567 err := bindSocketToInterface4(bind.v4.sock, interfaceIndex) 568 if err != nil { 569 return err 570 } 571 bind.v4.blackhole = blackhole 572 return nil 573 } 574 575 func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { 576 bind.mu.RLock() 577 defer bind.mu.RUnlock() 578 if bind.isOpen.Load() != 1 { 579 return net.ErrClosed 580 } 581 err := bindSocketToInterface6(bind.v6.sock, interfaceIndex) 582 if err != nil { 583 return err 584 } 585 bind.v6.blackhole = blackhole 586 return nil 587 } 588 589 func bindSocketToInterface4(handle windows.Handle, interfaceIndex uint32) error { 590 const IP_UNICAST_IF = 31 591 /* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */ 592 var bytes [4]byte 593 binary.BigEndian.PutUint32(bytes[:], interfaceIndex) 594 interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0])) 595 err := windows.SetsockoptInt(handle, windows.IPPROTO_IP, IP_UNICAST_IF, int(interfaceIndex)) 596 if err != nil { 597 return err 598 } 599 return nil 600 } 601 602 func bindSocketToInterface6(handle windows.Handle, interfaceIndex uint32) error { 603 const IPV6_UNICAST_IF = 31 604 return windows.SetsockoptInt(handle, windows.IPPROTO_IPV6, IPV6_UNICAST_IF, int(interfaceIndex)) 605 }