github.com/tailscale/wireguard-go@v0.0.20201119-0.20210522003738-46b531feb08a/conn/bind_windows.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package conn 7 8 import ( 9 "encoding/binary" 10 "io" 11 "net" 12 "strconv" 13 "sync" 14 "sync/atomic" 15 "unsafe" 16 17 "golang.org/x/sys/windows" 18 19 "github.com/tailscale/wireguard-go/conn/winrio" 20 ) 21 22 const ( 23 packetsPerRing = 1024 24 bytesPerPacket = 2048 - 32 25 receiveSpins = 15 26 ) 27 28 type ringPacket struct { 29 addr WinRingEndpoint 30 data [bytesPerPacket]byte 31 } 32 33 type ringBuffer struct { 34 packets uintptr 35 head, tail uint32 36 id winrio.BufferId 37 iocp windows.Handle 38 isFull bool 39 cq winrio.Cq 40 mu sync.Mutex 41 overlapped windows.Overlapped 42 } 43 44 func (rb *ringBuffer) Push() *ringPacket { 45 for rb.isFull { 46 panic("ring is full") 47 } 48 ret := (*ringPacket)(unsafe.Pointer(rb.packets + (uintptr(rb.tail%packetsPerRing) * unsafe.Sizeof(ringPacket{})))) 49 rb.tail += 1 50 if rb.tail%packetsPerRing == rb.head%packetsPerRing { 51 rb.isFull = true 52 } 53 return ret 54 } 55 56 func (rb *ringBuffer) Return(count uint32) { 57 if rb.head%packetsPerRing == rb.tail%packetsPerRing && !rb.isFull { 58 return 59 } 60 rb.head += count 61 rb.isFull = false 62 } 63 64 type afWinRingBind struct { 65 sock windows.Handle 66 rx, tx ringBuffer 67 rq winrio.Rq 68 mu sync.Mutex 69 blackhole bool 70 } 71 72 // WinRingBind uses Windows registered I/O for fast ring buffered networking. 73 type WinRingBind struct { 74 v4, v6 afWinRingBind 75 mu sync.RWMutex 76 isOpen uint32 77 } 78 79 func NewDefaultBind() Bind { return NewWinRingBind() } 80 81 func NewWinRingBind() Bind { 82 if !winrio.Initialize() { 83 return NewStdNetBind() 84 } 85 return new(WinRingBind) 86 } 87 88 type WinRingEndpoint struct { 89 family uint16 90 data [30]byte 91 } 92 93 var _ Bind = (*WinRingBind)(nil) 94 var _ Endpoint = (*WinRingEndpoint)(nil) 95 96 func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) { 97 host, port, err := net.SplitHostPort(s) 98 if err != nil { 99 return nil, err 100 } 101 host16, err := windows.UTF16PtrFromString(host) 102 if err != nil { 103 return nil, err 104 } 105 port16, err := windows.UTF16PtrFromString(port) 106 if err != nil { 107 return nil, err 108 } 109 hints := windows.AddrinfoW{ 110 Flags: windows.AI_NUMERICHOST, 111 Family: windows.AF_UNSPEC, 112 Socktype: windows.SOCK_DGRAM, 113 Protocol: windows.IPPROTO_UDP, 114 } 115 var addrinfo *windows.AddrinfoW 116 err = windows.GetAddrInfoW(host16, port16, &hints, &addrinfo) 117 if err != nil { 118 return nil, err 119 } 120 defer windows.FreeAddrInfoW(addrinfo) 121 if (addrinfo.Family != windows.AF_INET && addrinfo.Family != windows.AF_INET6) || addrinfo.Addrlen > unsafe.Sizeof(WinRingEndpoint{}) { 122 return nil, windows.ERROR_INVALID_ADDRESS 123 } 124 var src []byte 125 var dst [unsafe.Sizeof(WinRingEndpoint{})]byte 126 unsafeSlice(unsafe.Pointer(&src), unsafe.Pointer(addrinfo.Addr), int(addrinfo.Addrlen)) 127 copy(dst[:], src) 128 return (*WinRingEndpoint)(unsafe.Pointer(&dst[0])), nil 129 } 130 131 func (*WinRingEndpoint) ClearSrc() {} 132 133 func (e *WinRingEndpoint) DstIP() net.IP { 134 switch e.family { 135 case windows.AF_INET: 136 return append([]byte{}, e.data[2:6]...) 137 case windows.AF_INET6: 138 return append([]byte{}, e.data[6:22]...) 139 } 140 return nil 141 } 142 143 func (e *WinRingEndpoint) SrcIP() net.IP { 144 return nil // not supported 145 } 146 147 func (e *WinRingEndpoint) DstToBytes() []byte { 148 switch e.family { 149 case windows.AF_INET: 150 b := make([]byte, 0, 6) 151 b = append(b, e.data[2:6]...) 152 b = append(b, e.data[1], e.data[0]) 153 return b 154 case windows.AF_INET6: 155 b := make([]byte, 0, 18) 156 b = append(b, e.data[6:22]...) 157 b = append(b, e.data[1], e.data[0]) 158 return b 159 } 160 return nil 161 } 162 163 func (e *WinRingEndpoint) DstToString() string { 164 switch e.family { 165 case windows.AF_INET: 166 addr := net.UDPAddr{IP: e.data[2:6], Port: int(binary.BigEndian.Uint16(e.data[0:2]))} 167 return addr.String() 168 case windows.AF_INET6: 169 var zone string 170 if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 { 171 zone = strconv.FormatUint(uint64(scope), 10) 172 } 173 addr := net.UDPAddr{IP: e.data[6:22], Zone: zone, Port: int(binary.BigEndian.Uint16(e.data[0:2]))} 174 return addr.String() 175 } 176 return "" 177 } 178 179 func (e *WinRingEndpoint) SrcToString() string { 180 return "" 181 } 182 183 func (ring *ringBuffer) CloseAndZero() { 184 if ring.cq != 0 { 185 winrio.CloseCompletionQueue(ring.cq) 186 ring.cq = 0 187 } 188 if ring.iocp != 0 { 189 windows.CloseHandle(ring.iocp) 190 ring.iocp = 0 191 } 192 if ring.id != 0 { 193 winrio.DeregisterBuffer(ring.id) 194 ring.id = 0 195 } 196 if ring.packets != 0 { 197 windows.VirtualFree(ring.packets, 0, windows.MEM_RELEASE) 198 ring.packets = 0 199 } 200 ring.head = 0 201 ring.tail = 0 202 ring.isFull = false 203 } 204 205 func (bind *afWinRingBind) CloseAndZero() { 206 bind.rx.CloseAndZero() 207 bind.tx.CloseAndZero() 208 if bind.sock != 0 { 209 windows.CloseHandle(bind.sock) 210 bind.sock = 0 211 } 212 bind.blackhole = false 213 } 214 215 func (bind *WinRingBind) closeAndZero() { 216 atomic.StoreUint32(&bind.isOpen, 0) 217 bind.v4.CloseAndZero() 218 bind.v6.CloseAndZero() 219 } 220 221 func (ring *ringBuffer) Open() error { 222 var err error 223 packetsLen := unsafe.Sizeof(ringPacket{}) * packetsPerRing 224 ring.packets, err = windows.VirtualAlloc(0, packetsLen, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE) 225 if err != nil { 226 return err 227 } 228 ring.id, err = winrio.RegisterPointer(unsafe.Pointer(ring.packets), uint32(packetsLen)) 229 if err != nil { 230 return err 231 } 232 ring.iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0) 233 if err != nil { 234 return err 235 } 236 ring.cq, err = winrio.CreateIOCPCompletionQueue(packetsPerRing, ring.iocp, 0, &ring.overlapped) 237 if err != nil { 238 return err 239 } 240 return nil 241 } 242 243 func (bind *afWinRingBind) Open(family int32, sa windows.Sockaddr) (windows.Sockaddr, error) { 244 var err error 245 bind.sock, err = winrio.Socket(family, windows.SOCK_DGRAM, windows.IPPROTO_UDP) 246 if err != nil { 247 return nil, err 248 } 249 err = bind.rx.Open() 250 if err != nil { 251 return nil, err 252 } 253 err = bind.tx.Open() 254 if err != nil { 255 return nil, err 256 } 257 bind.rq, err = winrio.CreateRequestQueue(bind.sock, packetsPerRing, 1, packetsPerRing, 1, bind.rx.cq, bind.tx.cq, 0) 258 if err != nil { 259 return nil, err 260 } 261 err = windows.Bind(bind.sock, sa) 262 if err != nil { 263 return nil, err 264 } 265 sa, err = windows.Getsockname(bind.sock) 266 if err != nil { 267 return nil, err 268 } 269 return sa, nil 270 } 271 272 func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort uint16, err error) { 273 bind.mu.Lock() 274 defer bind.mu.Unlock() 275 defer func() { 276 if err != nil { 277 bind.closeAndZero() 278 } 279 }() 280 if atomic.LoadUint32(&bind.isOpen) != 0 { 281 return nil, 0, ErrBindAlreadyOpen 282 } 283 var sa windows.Sockaddr 284 sa, err = bind.v4.Open(windows.AF_INET, &windows.SockaddrInet4{Port: int(port)}) 285 if err != nil { 286 return nil, 0, err 287 } 288 sa, err = bind.v6.Open(windows.AF_INET6, &windows.SockaddrInet6{Port: sa.(*windows.SockaddrInet4).Port}) 289 if err != nil { 290 return nil, 0, err 291 } 292 selectedPort = uint16(sa.(*windows.SockaddrInet6).Port) 293 for i := 0; i < packetsPerRing; i++ { 294 err = bind.v4.InsertReceiveRequest() 295 if err != nil { 296 return nil, 0, err 297 } 298 err = bind.v6.InsertReceiveRequest() 299 if err != nil { 300 return nil, 0, err 301 } 302 } 303 atomic.StoreUint32(&bind.isOpen, 1) 304 return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err 305 } 306 307 func (bind *WinRingBind) Close() error { 308 bind.mu.RLock() 309 if atomic.LoadUint32(&bind.isOpen) != 1 { 310 bind.mu.RUnlock() 311 return nil 312 } 313 atomic.StoreUint32(&bind.isOpen, 2) 314 windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil) 315 windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil) 316 windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil) 317 windows.PostQueuedCompletionStatus(bind.v6.tx.iocp, 0, 0, nil) 318 bind.mu.RUnlock() 319 bind.mu.Lock() 320 defer bind.mu.Unlock() 321 bind.closeAndZero() 322 return nil 323 } 324 325 func (bind *WinRingBind) SetMark(mark uint32) error { 326 return nil 327 } 328 329 func (bind *afWinRingBind) InsertReceiveRequest() error { 330 packet := bind.rx.Push() 331 dataBuffer := &winrio.Buffer{ 332 Id: bind.rx.id, 333 Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.rx.packets), 334 Length: uint32(len(packet.data)), 335 } 336 addressBuffer := &winrio.Buffer{ 337 Id: bind.rx.id, 338 Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.rx.packets), 339 Length: uint32(unsafe.Sizeof(packet.addr)), 340 } 341 bind.mu.Lock() 342 defer bind.mu.Unlock() 343 return winrio.ReceiveEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, uintptr(unsafe.Pointer(packet))) 344 } 345 346 //go:linkname procyield runtime.procyield 347 func procyield(cycles uint32) 348 349 func (bind *afWinRingBind) Receive(buf []byte, isOpen *uint32) (int, Endpoint, error) { 350 if atomic.LoadUint32(isOpen) != 1 { 351 return 0, nil, net.ErrClosed 352 } 353 bind.rx.mu.Lock() 354 defer bind.rx.mu.Unlock() 355 356 var err error 357 var count uint32 358 var results [1]winrio.Result 359 retry: 360 count = 0 361 for tries := 0; count == 0 && tries < receiveSpins; tries++ { 362 if tries > 0 { 363 if atomic.LoadUint32(isOpen) != 1 { 364 return 0, nil, net.ErrClosed 365 } 366 procyield(1) 367 } 368 count = winrio.DequeueCompletion(bind.rx.cq, results[:]) 369 } 370 if count == 0 { 371 err = winrio.Notify(bind.rx.cq) 372 if err != nil { 373 return 0, nil, err 374 } 375 var bytes uint32 376 var key uintptr 377 var overlapped *windows.Overlapped 378 err = windows.GetQueuedCompletionStatus(bind.rx.iocp, &bytes, &key, &overlapped, windows.INFINITE) 379 if err != nil { 380 return 0, nil, err 381 } 382 if atomic.LoadUint32(isOpen) != 1 { 383 return 0, nil, net.ErrClosed 384 } 385 count = winrio.DequeueCompletion(bind.rx.cq, results[:]) 386 if count == 0 { 387 return 0, nil, io.ErrNoProgress 388 389 } 390 } 391 bind.rx.Return(1) 392 err = bind.InsertReceiveRequest() 393 if err != nil { 394 return 0, nil, err 395 } 396 // We limit the MTU well below the 65k max for practicality, but this means a remote host can still send us 397 // huge packets. Just try again when this happens. The infinite loop this could cause is still limited to 398 // attacker bandwidth, just like the rest of the receive path. 399 if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE { 400 if atomic.LoadUint32(isOpen) != 1 { 401 return 0, nil, net.ErrClosed 402 } 403 goto retry 404 } 405 if results[0].Status != 0 { 406 return 0, nil, windows.Errno(results[0].Status) 407 } 408 packet := (*ringPacket)(unsafe.Pointer(uintptr(results[0].RequestContext))) 409 ep := packet.addr 410 n := copy(buf, packet.data[:results[0].BytesTransferred]) 411 return n, &ep, nil 412 } 413 414 func (bind *WinRingBind) receiveIPv4(buf []byte) (int, Endpoint, error) { 415 bind.mu.RLock() 416 defer bind.mu.RUnlock() 417 return bind.v4.Receive(buf, &bind.isOpen) 418 } 419 420 func (bind *WinRingBind) receiveIPv6(buf []byte) (int, Endpoint, error) { 421 bind.mu.RLock() 422 defer bind.mu.RUnlock() 423 return bind.v6.Receive(buf, &bind.isOpen) 424 } 425 426 func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *uint32) error { 427 if atomic.LoadUint32(isOpen) != 1 { 428 return net.ErrClosed 429 } 430 if len(buf) > bytesPerPacket { 431 return io.ErrShortBuffer 432 } 433 bind.tx.mu.Lock() 434 defer bind.tx.mu.Unlock() 435 var results [packetsPerRing]winrio.Result 436 count := winrio.DequeueCompletion(bind.tx.cq, results[:]) 437 if count == 0 && bind.tx.isFull { 438 err := winrio.Notify(bind.tx.cq) 439 if err != nil { 440 return err 441 } 442 var bytes uint32 443 var key uintptr 444 var overlapped *windows.Overlapped 445 err = windows.GetQueuedCompletionStatus(bind.tx.iocp, &bytes, &key, &overlapped, windows.INFINITE) 446 if err != nil { 447 return err 448 } 449 if atomic.LoadUint32(isOpen) != 1 { 450 return net.ErrClosed 451 } 452 count = winrio.DequeueCompletion(bind.tx.cq, results[:]) 453 if count == 0 { 454 return io.ErrNoProgress 455 } 456 } 457 if count > 0 { 458 bind.tx.Return(count) 459 } 460 packet := bind.tx.Push() 461 packet.addr = *nend 462 copy(packet.data[:], buf) 463 dataBuffer := &winrio.Buffer{ 464 Id: bind.tx.id, 465 Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.tx.packets), 466 Length: uint32(len(buf)), 467 } 468 addressBuffer := &winrio.Buffer{ 469 Id: bind.tx.id, 470 Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.tx.packets), 471 Length: uint32(unsafe.Sizeof(packet.addr)), 472 } 473 bind.mu.Lock() 474 defer bind.mu.Unlock() 475 return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0) 476 } 477 478 func (bind *WinRingBind) Send(buf []byte, endpoint Endpoint) error { 479 nend, ok := endpoint.(*WinRingEndpoint) 480 if !ok { 481 return ErrWrongEndpointType 482 } 483 bind.mu.RLock() 484 defer bind.mu.RUnlock() 485 switch nend.family { 486 case windows.AF_INET: 487 if bind.v4.blackhole { 488 return nil 489 } 490 return bind.v4.Send(buf, nend, &bind.isOpen) 491 case windows.AF_INET6: 492 if bind.v6.blackhole { 493 return nil 494 } 495 return bind.v6.Send(buf, nend, &bind.isOpen) 496 } 497 return nil 498 } 499 500 func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { 501 bind.mu.Lock() 502 defer bind.mu.Unlock() 503 sysconn, err := bind.ipv4.SyscallConn() 504 if err != nil { 505 return err 506 } 507 err2 := sysconn.Control(func(fd uintptr) { 508 err = bindSocketToInterface4(windows.Handle(fd), interfaceIndex) 509 }) 510 if err2 != nil { 511 return err2 512 } 513 if err != nil { 514 return err 515 } 516 bind.blackhole4 = blackhole 517 return nil 518 } 519 520 func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { 521 bind.mu.Lock() 522 defer bind.mu.Unlock() 523 sysconn, err := bind.ipv6.SyscallConn() 524 if err != nil { 525 return err 526 } 527 err2 := sysconn.Control(func(fd uintptr) { 528 err = bindSocketToInterface6(windows.Handle(fd), interfaceIndex) 529 }) 530 if err2 != nil { 531 return err2 532 } 533 if err != nil { 534 return err 535 } 536 bind.blackhole6 = blackhole 537 return nil 538 } 539 func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { 540 bind.mu.RLock() 541 defer bind.mu.RUnlock() 542 if atomic.LoadUint32(&bind.isOpen) != 1 { 543 return net.ErrClosed 544 } 545 err := bindSocketToInterface4(bind.v4.sock, interfaceIndex) 546 if err != nil { 547 return err 548 } 549 bind.v4.blackhole = blackhole 550 return nil 551 } 552 553 func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { 554 bind.mu.RLock() 555 defer bind.mu.RUnlock() 556 if atomic.LoadUint32(&bind.isOpen) != 1 { 557 return net.ErrClosed 558 } 559 err := bindSocketToInterface6(bind.v6.sock, interfaceIndex) 560 if err != nil { 561 return err 562 } 563 bind.v6.blackhole = blackhole 564 return nil 565 } 566 567 func bindSocketToInterface4(handle windows.Handle, interfaceIndex uint32) error { 568 const IP_UNICAST_IF = 31 569 /* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */ 570 var bytes [4]byte 571 binary.BigEndian.PutUint32(bytes[:], interfaceIndex) 572 interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0])) 573 err := windows.SetsockoptInt(handle, windows.IPPROTO_IP, IP_UNICAST_IF, int(interfaceIndex)) 574 if err != nil { 575 return err 576 } 577 return nil 578 } 579 580 func bindSocketToInterface6(handle windows.Handle, interfaceIndex uint32) error { 581 const IPV6_UNICAST_IF = 31 582 return windows.SetsockoptInt(handle, windows.IPPROTO_IPV6, IPV6_UNICAST_IF, int(interfaceIndex)) 583 } 584 585 // unsafeSlice updates the slice slicePtr to be a slice 586 // referencing the provided data with its length & capacity set to 587 // lenCap. 588 // 589 // TODO: when Go 1.16 or Go 1.17 is the minimum supported version, 590 // update callers to use unsafe.Slice instead of this. 591 func unsafeSlice(slicePtr, data unsafe.Pointer, lenCap int) { 592 type sliceHeader struct { 593 Data unsafe.Pointer 594 Len int 595 Cap int 596 } 597 h := (*sliceHeader)(slicePtr) 598 h.Data = data 599 h.Len = lenCap 600 h.Cap = lenCap 601 }