github.com/bugfan/wireguard-go@v0.0.0-20230720020150-a7b2fa340c66/conn/bind_windows.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package conn 7 8 import ( 9 "encoding/binary" 10 "io" 11 "net" 12 "strconv" 13 "sync" 14 "sync/atomic" 15 "unsafe" 16 17 "golang.org/x/sys/windows" 18 19 "github.com/bugfan/wireguard-go/conn/winrio" 20 ) 21 22 const ( 23 packetsPerRing = 1024 24 bytesPerPacket = 2048 - 32 25 receiveSpins = 15 26 ) 27 28 type ringPacket struct { 29 addr WinRingEndpoint 30 data [bytesPerPacket]byte 31 } 32 33 type ringBuffer struct { 34 packets uintptr 35 head, tail uint32 36 id winrio.BufferId 37 iocp windows.Handle 38 isFull bool 39 cq winrio.Cq 40 mu sync.Mutex 41 overlapped windows.Overlapped 42 } 43 44 func (rb *ringBuffer) Push() *ringPacket { 45 for rb.isFull { 46 panic("ring is full") 47 } 48 ret := (*ringPacket)(unsafe.Pointer(rb.packets + (uintptr(rb.tail%packetsPerRing) * unsafe.Sizeof(ringPacket{})))) 49 rb.tail += 1 50 if rb.tail%packetsPerRing == rb.head%packetsPerRing { 51 rb.isFull = true 52 } 53 return ret 54 } 55 56 func (rb *ringBuffer) Return(count uint32) { 57 if rb.head%packetsPerRing == rb.tail%packetsPerRing && !rb.isFull { 58 return 59 } 60 rb.head += count 61 rb.isFull = false 62 } 63 64 type afWinRingBind struct { 65 sock windows.Handle 66 rx, tx ringBuffer 67 rq winrio.Rq 68 mu sync.Mutex 69 blackhole bool 70 } 71 72 // WinRingBind uses Windows registered I/O for fast ring buffered networking. 73 type WinRingBind struct { 74 v4, v6 afWinRingBind 75 mu sync.RWMutex 76 isOpen uint32 77 } 78 79 func NewDefaultBind() Bind { return NewWinRingBind() } 80 81 func NewWinRingBind() Bind { 82 if !winrio.Initialize() { 83 return NewStdNetBind() 84 } 85 return new(WinRingBind) 86 } 87 88 type WinRingEndpoint struct { 89 family uint16 90 data [30]byte 91 } 92 93 var _ Bind = (*WinRingBind)(nil) 94 var _ Endpoint = (*WinRingEndpoint)(nil) 95 96 func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) { 97 host, port, err := net.SplitHostPort(s) 98 if err != nil { 99 return nil, err 100 } 101 host16, err := windows.UTF16PtrFromString(host) 102 if err != nil { 103 return nil, err 104 } 105 port16, err := windows.UTF16PtrFromString(port) 106 if err != nil { 107 return nil, err 108 } 109 hints := windows.AddrinfoW{ 110 Flags: windows.AI_NUMERICHOST, 111 Family: windows.AF_UNSPEC, 112 Socktype: windows.SOCK_DGRAM, 113 Protocol: windows.IPPROTO_UDP, 114 } 115 var addrinfo *windows.AddrinfoW 116 err = windows.GetAddrInfoW(host16, port16, &hints, &addrinfo) 117 if err != nil { 118 return nil, err 119 } 120 defer windows.FreeAddrInfoW(addrinfo) 121 if (addrinfo.Family != windows.AF_INET && addrinfo.Family != windows.AF_INET6) || addrinfo.Addrlen > unsafe.Sizeof(WinRingEndpoint{}) { 122 return nil, windows.ERROR_INVALID_ADDRESS 123 } 124 var dst [unsafe.Sizeof(WinRingEndpoint{})]byte 125 copy(dst[:], unsafe.Slice((*byte)(unsafe.Pointer(addrinfo.Addr)), addrinfo.Addrlen)) 126 return (*WinRingEndpoint)(unsafe.Pointer(&dst[0])), nil 127 } 128 129 func (*WinRingEndpoint) ClearSrc() {} 130 131 func (e *WinRingEndpoint) DstIP() net.IP { 132 switch e.family { 133 case windows.AF_INET: 134 return append([]byte{}, e.data[2:6]...) 135 case windows.AF_INET6: 136 return append([]byte{}, e.data[6:22]...) 137 } 138 return nil 139 } 140 141 func (e *WinRingEndpoint) SrcIP() net.IP { 142 return nil // not supported 143 } 144 145 func (e *WinRingEndpoint) DstToBytes() []byte { 146 switch e.family { 147 case windows.AF_INET: 148 b := make([]byte, 0, 6) 149 b = append(b, e.data[2:6]...) 150 b = append(b, e.data[1], e.data[0]) 151 return b 152 case windows.AF_INET6: 153 b := make([]byte, 0, 18) 154 b = append(b, e.data[6:22]...) 155 b = append(b, e.data[1], e.data[0]) 156 return b 157 } 158 return nil 159 } 160 161 func (e *WinRingEndpoint) DstToString() string { 162 switch e.family { 163 case windows.AF_INET: 164 addr := net.UDPAddr{IP: e.data[2:6], Port: int(binary.BigEndian.Uint16(e.data[0:2]))} 165 return addr.String() 166 case windows.AF_INET6: 167 var zone string 168 if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 { 169 zone = strconv.FormatUint(uint64(scope), 10) 170 } 171 addr := net.UDPAddr{IP: e.data[6:22], Zone: zone, Port: int(binary.BigEndian.Uint16(e.data[0:2]))} 172 return addr.String() 173 } 174 return "" 175 } 176 177 func (e *WinRingEndpoint) SrcToString() string { 178 return "" 179 } 180 181 func (ring *ringBuffer) CloseAndZero() { 182 if ring.cq != 0 { 183 winrio.CloseCompletionQueue(ring.cq) 184 ring.cq = 0 185 } 186 if ring.iocp != 0 { 187 windows.CloseHandle(ring.iocp) 188 ring.iocp = 0 189 } 190 if ring.id != 0 { 191 winrio.DeregisterBuffer(ring.id) 192 ring.id = 0 193 } 194 if ring.packets != 0 { 195 windows.VirtualFree(ring.packets, 0, windows.MEM_RELEASE) 196 ring.packets = 0 197 } 198 ring.head = 0 199 ring.tail = 0 200 ring.isFull = false 201 } 202 203 func (bind *afWinRingBind) CloseAndZero() { 204 bind.rx.CloseAndZero() 205 bind.tx.CloseAndZero() 206 if bind.sock != 0 { 207 windows.CloseHandle(bind.sock) 208 bind.sock = 0 209 } 210 bind.blackhole = false 211 } 212 213 func (bind *WinRingBind) closeAndZero() { 214 atomic.StoreUint32(&bind.isOpen, 0) 215 bind.v4.CloseAndZero() 216 bind.v6.CloseAndZero() 217 } 218 219 func (ring *ringBuffer) Open() error { 220 var err error 221 packetsLen := unsafe.Sizeof(ringPacket{}) * packetsPerRing 222 ring.packets, err = windows.VirtualAlloc(0, packetsLen, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE) 223 if err != nil { 224 return err 225 } 226 ring.id, err = winrio.RegisterPointer(unsafe.Pointer(ring.packets), uint32(packetsLen)) 227 if err != nil { 228 return err 229 } 230 ring.iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0) 231 if err != nil { 232 return err 233 } 234 ring.cq, err = winrio.CreateIOCPCompletionQueue(packetsPerRing, ring.iocp, 0, &ring.overlapped) 235 if err != nil { 236 return err 237 } 238 return nil 239 } 240 241 func (bind *afWinRingBind) Open(family int32, sa windows.Sockaddr) (windows.Sockaddr, error) { 242 var err error 243 bind.sock, err = winrio.Socket(family, windows.SOCK_DGRAM, windows.IPPROTO_UDP) 244 if err != nil { 245 return nil, err 246 } 247 err = bind.rx.Open() 248 if err != nil { 249 return nil, err 250 } 251 err = bind.tx.Open() 252 if err != nil { 253 return nil, err 254 } 255 bind.rq, err = winrio.CreateRequestQueue(bind.sock, packetsPerRing, 1, packetsPerRing, 1, bind.rx.cq, bind.tx.cq, 0) 256 if err != nil { 257 return nil, err 258 } 259 err = windows.Bind(bind.sock, sa) 260 if err != nil { 261 return nil, err 262 } 263 sa, err = windows.Getsockname(bind.sock) 264 if err != nil { 265 return nil, err 266 } 267 return sa, nil 268 } 269 270 func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort uint16, err error) { 271 bind.mu.Lock() 272 defer bind.mu.Unlock() 273 defer func() { 274 if err != nil { 275 bind.closeAndZero() 276 } 277 }() 278 if atomic.LoadUint32(&bind.isOpen) != 0 { 279 return nil, 0, ErrBindAlreadyOpen 280 } 281 var sa windows.Sockaddr 282 sa, err = bind.v4.Open(windows.AF_INET, &windows.SockaddrInet4{Port: int(port)}) 283 if err != nil { 284 return nil, 0, err 285 } 286 sa, err = bind.v6.Open(windows.AF_INET6, &windows.SockaddrInet6{Port: sa.(*windows.SockaddrInet4).Port}) 287 if err != nil { 288 return nil, 0, err 289 } 290 selectedPort = uint16(sa.(*windows.SockaddrInet6).Port) 291 for i := 0; i < packetsPerRing; i++ { 292 err = bind.v4.InsertReceiveRequest() 293 if err != nil { 294 return nil, 0, err 295 } 296 err = bind.v6.InsertReceiveRequest() 297 if err != nil { 298 return nil, 0, err 299 } 300 } 301 atomic.StoreUint32(&bind.isOpen, 1) 302 return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err 303 } 304 305 func (bind *WinRingBind) Close() error { 306 bind.mu.RLock() 307 if atomic.LoadUint32(&bind.isOpen) != 1 { 308 bind.mu.RUnlock() 309 return nil 310 } 311 atomic.StoreUint32(&bind.isOpen, 2) 312 windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil) 313 windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil) 314 windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil) 315 windows.PostQueuedCompletionStatus(bind.v6.tx.iocp, 0, 0, nil) 316 bind.mu.RUnlock() 317 bind.mu.Lock() 318 defer bind.mu.Unlock() 319 bind.closeAndZero() 320 return nil 321 } 322 323 func (bind *WinRingBind) SetMark(mark uint32) error { 324 return nil 325 } 326 327 func (bind *afWinRingBind) InsertReceiveRequest() error { 328 packet := bind.rx.Push() 329 dataBuffer := &winrio.Buffer{ 330 Id: bind.rx.id, 331 Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.rx.packets), 332 Length: uint32(len(packet.data)), 333 } 334 addressBuffer := &winrio.Buffer{ 335 Id: bind.rx.id, 336 Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.rx.packets), 337 Length: uint32(unsafe.Sizeof(packet.addr)), 338 } 339 bind.mu.Lock() 340 defer bind.mu.Unlock() 341 return winrio.ReceiveEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, uintptr(unsafe.Pointer(packet))) 342 } 343 344 //go:linkname procyield runtime.procyield 345 func procyield(cycles uint32) 346 347 func (bind *afWinRingBind) Receive(buf []byte, isOpen *uint32) (int, Endpoint, error) { 348 if atomic.LoadUint32(isOpen) != 1 { 349 return 0, nil, net.ErrClosed 350 } 351 bind.rx.mu.Lock() 352 defer bind.rx.mu.Unlock() 353 354 var err error 355 var count uint32 356 var results [1]winrio.Result 357 retry: 358 count = 0 359 for tries := 0; count == 0 && tries < receiveSpins; tries++ { 360 if tries > 0 { 361 if atomic.LoadUint32(isOpen) != 1 { 362 return 0, nil, net.ErrClosed 363 } 364 procyield(1) 365 } 366 count = winrio.DequeueCompletion(bind.rx.cq, results[:]) 367 } 368 if count == 0 { 369 err = winrio.Notify(bind.rx.cq) 370 if err != nil { 371 return 0, nil, err 372 } 373 var bytes uint32 374 var key uintptr 375 var overlapped *windows.Overlapped 376 err = windows.GetQueuedCompletionStatus(bind.rx.iocp, &bytes, &key, &overlapped, windows.INFINITE) 377 if err != nil { 378 return 0, nil, err 379 } 380 if atomic.LoadUint32(isOpen) != 1 { 381 return 0, nil, net.ErrClosed 382 } 383 count = winrio.DequeueCompletion(bind.rx.cq, results[:]) 384 if count == 0 { 385 return 0, nil, io.ErrNoProgress 386 387 } 388 } 389 bind.rx.Return(1) 390 err = bind.InsertReceiveRequest() 391 if err != nil { 392 return 0, nil, err 393 } 394 // We limit the MTU well below the 65k max for practicality, but this means a remote host can still send us 395 // huge packets. Just try again when this happens. The infinite loop this could cause is still limited to 396 // attacker bandwidth, just like the rest of the receive path. 397 if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE { 398 if atomic.LoadUint32(isOpen) != 1 { 399 return 0, nil, net.ErrClosed 400 } 401 goto retry 402 } 403 if results[0].Status != 0 { 404 return 0, nil, windows.Errno(results[0].Status) 405 } 406 packet := (*ringPacket)(unsafe.Pointer(uintptr(results[0].RequestContext))) 407 ep := packet.addr 408 n := copy(buf, packet.data[:results[0].BytesTransferred]) 409 return n, &ep, nil 410 } 411 412 func (bind *WinRingBind) receiveIPv4(buf []byte) (int, Endpoint, error) { 413 bind.mu.RLock() 414 defer bind.mu.RUnlock() 415 return bind.v4.Receive(buf, &bind.isOpen) 416 } 417 418 func (bind *WinRingBind) receiveIPv6(buf []byte) (int, Endpoint, error) { 419 bind.mu.RLock() 420 defer bind.mu.RUnlock() 421 return bind.v6.Receive(buf, &bind.isOpen) 422 } 423 424 func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *uint32) error { 425 if atomic.LoadUint32(isOpen) != 1 { 426 return net.ErrClosed 427 } 428 if len(buf) > bytesPerPacket { 429 return io.ErrShortBuffer 430 } 431 bind.tx.mu.Lock() 432 defer bind.tx.mu.Unlock() 433 var results [packetsPerRing]winrio.Result 434 count := winrio.DequeueCompletion(bind.tx.cq, results[:]) 435 if count == 0 && bind.tx.isFull { 436 err := winrio.Notify(bind.tx.cq) 437 if err != nil { 438 return err 439 } 440 var bytes uint32 441 var key uintptr 442 var overlapped *windows.Overlapped 443 err = windows.GetQueuedCompletionStatus(bind.tx.iocp, &bytes, &key, &overlapped, windows.INFINITE) 444 if err != nil { 445 return err 446 } 447 if atomic.LoadUint32(isOpen) != 1 { 448 return net.ErrClosed 449 } 450 count = winrio.DequeueCompletion(bind.tx.cq, results[:]) 451 if count == 0 { 452 return io.ErrNoProgress 453 } 454 } 455 if count > 0 { 456 bind.tx.Return(count) 457 } 458 packet := bind.tx.Push() 459 packet.addr = *nend 460 copy(packet.data[:], buf) 461 dataBuffer := &winrio.Buffer{ 462 Id: bind.tx.id, 463 Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.tx.packets), 464 Length: uint32(len(buf)), 465 } 466 addressBuffer := &winrio.Buffer{ 467 Id: bind.tx.id, 468 Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.tx.packets), 469 Length: uint32(unsafe.Sizeof(packet.addr)), 470 } 471 bind.mu.Lock() 472 defer bind.mu.Unlock() 473 return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0) 474 } 475 476 func (bind *WinRingBind) Send(buf []byte, endpoint Endpoint) error { 477 nend, ok := endpoint.(*WinRingEndpoint) 478 if !ok { 479 return ErrWrongEndpointType 480 } 481 bind.mu.RLock() 482 defer bind.mu.RUnlock() 483 switch nend.family { 484 case windows.AF_INET: 485 if bind.v4.blackhole { 486 return nil 487 } 488 return bind.v4.Send(buf, nend, &bind.isOpen) 489 case windows.AF_INET6: 490 if bind.v6.blackhole { 491 return nil 492 } 493 return bind.v6.Send(buf, nend, &bind.isOpen) 494 } 495 return nil 496 } 497 498 func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { 499 bind.mu.Lock() 500 defer bind.mu.Unlock() 501 sysconn, err := bind.ipv4.SyscallConn() 502 if err != nil { 503 return err 504 } 505 err2 := sysconn.Control(func(fd uintptr) { 506 err = bindSocketToInterface4(windows.Handle(fd), interfaceIndex) 507 }) 508 if err2 != nil { 509 return err2 510 } 511 if err != nil { 512 return err 513 } 514 bind.blackhole4 = blackhole 515 return nil 516 } 517 518 func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { 519 bind.mu.Lock() 520 defer bind.mu.Unlock() 521 sysconn, err := bind.ipv6.SyscallConn() 522 if err != nil { 523 return err 524 } 525 err2 := sysconn.Control(func(fd uintptr) { 526 err = bindSocketToInterface6(windows.Handle(fd), interfaceIndex) 527 }) 528 if err2 != nil { 529 return err2 530 } 531 if err != nil { 532 return err 533 } 534 bind.blackhole6 = blackhole 535 return nil 536 } 537 func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { 538 bind.mu.RLock() 539 defer bind.mu.RUnlock() 540 if atomic.LoadUint32(&bind.isOpen) != 1 { 541 return net.ErrClosed 542 } 543 err := bindSocketToInterface4(bind.v4.sock, interfaceIndex) 544 if err != nil { 545 return err 546 } 547 bind.v4.blackhole = blackhole 548 return nil 549 } 550 551 func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { 552 bind.mu.RLock() 553 defer bind.mu.RUnlock() 554 if atomic.LoadUint32(&bind.isOpen) != 1 { 555 return net.ErrClosed 556 } 557 err := bindSocketToInterface6(bind.v6.sock, interfaceIndex) 558 if err != nil { 559 return err 560 } 561 bind.v6.blackhole = blackhole 562 return nil 563 } 564 565 func bindSocketToInterface4(handle windows.Handle, interfaceIndex uint32) error { 566 const IP_UNICAST_IF = 31 567 /* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */ 568 var bytes [4]byte 569 binary.BigEndian.PutUint32(bytes[:], interfaceIndex) 570 interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0])) 571 err := windows.SetsockoptInt(handle, windows.IPPROTO_IP, IP_UNICAST_IF, int(interfaceIndex)) 572 if err != nil { 573 return err 574 } 575 return nil 576 } 577 578 func bindSocketToInterface6(handle windows.Handle, interfaceIndex uint32) error { 579 const IPV6_UNICAST_IF = 31 580 return windows.SetsockoptInt(handle, windows.IPPROTO_IPV6, IPV6_UNICAST_IF, int(interfaceIndex)) 581 }