//go:build !e2e_testing
// +build !e2e_testing

// Inspired by https://git.zx2c4.com/wireguard-go/tree/conn/bind_windows.go

package udp

import (
	"errors"
	"fmt"
	"io"
	"net"
	"sync"
	"sync/atomic"
	"syscall"
	"unsafe"

	"github.com/sirupsen/logrus"
	"github.com/slackhq/nebula/config"
	"github.com/slackhq/nebula/firewall"
	"github.com/slackhq/nebula/header"

	"golang.org/x/sys/windows"
	"golang.zx2c4.com/wireguard/conn/winrio"
)

// Assert we meet the standard conn interface
var _ Conn = &RIOConn{}

// procyield spins for roughly the given number of CPU cycles without
// yielding the OS thread; linked to the Go runtime's internal helper.
//
//go:linkname procyield runtime.procyield
func procyield(cycles uint32)

const (
	packetsPerRing = 1024      // slots in each of the rx and tx rings
	bytesPerPacket = 2048 - 32 // payload bytes per slot; 32 bytes reserved for the sockaddr
	receiveSpins   = 15        // busy-poll attempts before blocking on the completion port
)

// ringPacket is one slot of a registered I/O ring: the peer's raw address
// followed by the packet payload.
type ringPacket struct {
	addr windows.RawSockaddrInet6
	data [bytesPerPacket]byte
}

// ringBuffer is a fixed-size circular buffer of ringPackets whose backing
// memory is registered with RIO so the kernel can read/write it directly.
// head and tail are free-running counters; a slot index is counter %
// packetsPerRing, and isFull disambiguates the head==tail case (empty vs
// full).
type ringBuffer struct {
	packets    uintptr // base address of the VirtualAlloc'd slot array
	head, tail uint32
	id         winrio.BufferId
	iocp       windows.Handle // used only to block when there is nothing to dequeue
	isFull     bool
	cq         winrio.Cq
	mu         sync.Mutex // serializes all operations on this ring
	overlapped windows.Overlapped
}

// RIOConn is a UDP socket backed by Windows Registered I/O (RIO), dequeuing
// completions from shared rings instead of making per-packet syscalls.
type RIOConn struct {
	isOpen  atomic.Bool // cleared by Close; checked before and after blocking waits
	l       *logrus.Logger
	sock    windows.Handle
	rx, tx  ringBuffer
	rq      winrio.Rq
	results [packetsPerRing]winrio.Result
}

// NewRIOListener creates a dual-stack (IPv6 with v4 enabled) RIO UDP socket
// bound to ip:port and pre-posts a full ring of receive requests.
func NewRIOListener(l *logrus.Logger, ip net.IP, port int) (*RIOConn, error) {
	if !winrio.Initialize() {
		return nil, errors.New("could not initialize winrio")
	}

	u := &RIOConn{l: l}

	addr := [16]byte{}
	copy(addr[:], ip.To16())
	err := u.bind(&windows.SockaddrInet6{Addr: addr, Port: port})
	if err != nil {
		return nil, fmt.Errorf("bind: %w", err)
	}

	// Fill the receive ring up front so the kernel can deliver packets
	// without waiting for the read loop.
	for i := 0; i < packetsPerRing; i++ {
		err = u.insertReceiveRequest()
		if err != nil {
			return nil, fmt.Errorf("init rx ring: %w", err)
		}
	}

	u.isOpen.Store(true)
	return u, nil
}

// bind creates the RIO socket, opens the rx/tx rings, creates the request
// queue tied to both completion queues, and binds the socket to sa.
// NOTE(review): resources acquired here are not torn down when a later step
// fails; callers appear to treat that as fatal — confirm before reusing.
func (u *RIOConn) bind(sa windows.Sockaddr) error {
	var err error
	u.sock, err = winrio.Socket(windows.AF_INET6, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
	if err != nil {
		return err
	}

	// Enable v4 for this socket (best effort; the error is deliberately ignored)
	syscall.SetsockoptInt(syscall.Handle(u.sock), syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0)

	err = u.rx.Open()
	if err != nil {
		return err
	}

	err = u.tx.Open()
	if err != nil {
		return err
	}

	u.rq, err = winrio.CreateRequestQueue(u.sock, packetsPerRing, 1, packetsPerRing, 1, u.rx.cq, u.tx.cq, 0)
	if err != nil {
		return err
	}

	err = windows.Bind(u.sock, sa)
	if err != nil {
		return err
	}

	return nil
}

// ListenOut reads packets one at a time and dispatches each to r, reusing
// all scratch buffers across iterations. It runs until the socket is closed.
func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
	plaintext := make([]byte, MTU)
	buffer := make([]byte, MTU)
	h := &header.H{}
	fwPacket := &firewall.Packet{}
	udpAddr := &Addr{IP: make([]byte, 16)}
	nb := make([]byte, 12, 12)

	for {
		// Just read one packet at a time
		n, rua, err := u.receive(buffer)
		if err != nil {
			u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
			return
		}

		udpAddr.IP = rua.Addr[:]
		// The raw sockaddr carries the port in network (big-endian) order;
		// write it byte-by-byte into the native uint16 to convert.
		p := (*[2]byte)(unsafe.Pointer(&udpAddr.Port))
		p[0] = byte(rua.Port >> 8)
		p[1] = byte(rua.Port)
		r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l))
	}
}

// insertReceiveRequest posts one receive into the RIO request queue aimed at
// the next free rx ring slot. The slot's address is passed as the request
// context so receive() can locate the completed packet.
func (u *RIOConn) insertReceiveRequest() error {
	packet := u.rx.Push()
	dataBuffer := &winrio.Buffer{
		Id:     u.rx.id,
		Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - u.rx.packets),
		Length: uint32(len(packet.data)),
	}
	addressBuffer := &winrio.Buffer{
		Id:     u.rx.id,
		Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - u.rx.packets),
		Length: uint32(unsafe.Sizeof(packet.addr)),
	}

	return winrio.ReceiveEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, uintptr(unsafe.Pointer(packet)))
}

// receive copies the next inbound packet into buf, returning its length and
// the sender's raw address. It busy-polls the completion queue for up to
// receiveSpins iterations, then arms a notification and blocks on the IOCP.
// The consumed ring slot is recycled by posting a fresh receive request
// before returning.
func (u *RIOConn) receive(buf []byte) (int, windows.RawSockaddrInet6, error) {
	if !u.isOpen.Load() {
		return 0, windows.RawSockaddrInet6{}, net.ErrClosed
	}

	u.rx.mu.Lock()
	defer u.rx.mu.Unlock()

	var err error
	var count uint32
	var results [1]winrio.Result

retry:
	count = 0
	// Spin briefly before sleeping; under load a completion usually shows up
	// within a few iterations.
	for tries := 0; count == 0 && tries < receiveSpins; tries++ {
		if tries > 0 {
			if !u.isOpen.Load() {
				return 0, windows.RawSockaddrInet6{}, net.ErrClosed
			}
			procyield(1)
		}

		count = winrio.DequeueCompletion(u.rx.cq, results[:])
	}

	if count == 0 {
		// Nothing ready: ask RIO to post to the IOCP on the next completion
		// and block until it does (or until Close posts a wakeup).
		err = winrio.Notify(u.rx.cq)
		if err != nil {
			return 0, windows.RawSockaddrInet6{}, err
		}
		var bytes uint32
		var key uintptr
		var overlapped *windows.Overlapped
		err = windows.GetQueuedCompletionStatus(u.rx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
		if err != nil {
			return 0, windows.RawSockaddrInet6{}, err
		}

		// Close posts a zero-byte completion to wake us; re-check the flag.
		if !u.isOpen.Load() {
			return 0, windows.RawSockaddrInet6{}, net.ErrClosed
		}

		count = winrio.DequeueCompletion(u.rx.cq, results[:])
		if count == 0 {
			return 0, windows.RawSockaddrInet6{}, io.ErrNoProgress
		}
	}

	// Release the consumed slot and immediately post a new receive so the
	// ring stays full.
	u.rx.Return(1)
	err = u.insertReceiveRequest()
	if err != nil {
		return 0, windows.RawSockaddrInet6{}, err
	}

	// We limit the MTU well below the 65k max for practicality, but this means a remote host can still send us
	// huge packets. Just try again when this happens. The infinite loop this could cause is still limited to
	// attacker bandwidth, just like the rest of the receive path.
	if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE {
		goto retry
	}

	if results[0].Status != 0 {
		return 0, windows.RawSockaddrInet6{}, windows.Errno(results[0].Status)
	}

	// RequestContext is the ring slot pointer we handed to ReceiveEx.
	packet := (*ringPacket)(unsafe.Pointer(uintptr(results[0].RequestContext)))
	ep := packet.addr
	n := copy(buf, packet.data[:results[0].BytesTransferred])
	return n, ep, nil
}

// WriteTo sends buf to addr. Completed send slots are reclaimed lazily; when
// the tx ring is completely full the call blocks on the IOCP until at least
// one in-flight send finishes. Returns io.ErrShortBuffer if buf exceeds a
// ring slot's capacity.
func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error {
	if !u.isOpen.Load() {
		return net.ErrClosed
	}

	if len(buf) > bytesPerPacket {
		return io.ErrShortBuffer
	}

	u.tx.mu.Lock()
	defer u.tx.mu.Unlock()

	// Reclaim any sends that completed since the last call.
	count := winrio.DequeueCompletion(u.tx.cq, u.results[:])
	if count == 0 && u.tx.isFull {
		// Ring is full and nothing has completed yet: arm a notification and
		// wait for the kernel to finish at least one send.
		err := winrio.Notify(u.tx.cq)
		if err != nil {
			return err
		}

		var bytes uint32
		var key uintptr
		var overlapped *windows.Overlapped
		err = windows.GetQueuedCompletionStatus(u.tx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
		if err != nil {
			return err
		}

		// Close posts a zero-byte completion to wake us; re-check the flag.
		if !u.isOpen.Load() {
			return net.ErrClosed
		}

		count = winrio.DequeueCompletion(u.tx.cq, u.results[:])
		if count == 0 {
			return io.ErrNoProgress
		}
	}

	if count > 0 {
		u.tx.Return(count)
	}

	// Copy the destination and payload into the next free tx slot. The port
	// is written byte-by-byte to produce network byte order in the raw
	// sockaddr.
	packet := u.tx.Push()
	packet.addr.Family = windows.AF_INET6
	p := (*[2]byte)(unsafe.Pointer(&packet.addr.Port))
	p[0] = byte(addr.Port >> 8)
	p[1] = byte(addr.Port)
	copy(packet.addr.Addr[:], addr.IP.To16())
	copy(packet.data[:], buf)

	dataBuffer := &winrio.Buffer{
		Id:     u.tx.id,
		Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - u.tx.packets),
		Length: uint32(len(buf)),
	}

	addressBuffer := &winrio.Buffer{
		Id:     u.tx.id,
		Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - u.tx.packets),
		Length: uint32(unsafe.Sizeof(packet.addr)),
	}

	return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
}
func (u *RIOConn) LocalAddr() (*Addr, error) { 299 sa, err := windows.Getsockname(u.sock) 300 if err != nil { 301 return nil, err 302 } 303 304 v6 := sa.(*windows.SockaddrInet6) 305 return &Addr{ 306 IP: v6.Addr[:], 307 Port: uint16(v6.Port), 308 }, nil 309 } 310 311 func (u *RIOConn) Rebind() error { 312 return nil 313 } 314 315 func (u *RIOConn) ReloadConfig(*config.C) {} 316 317 func (u *RIOConn) Close() error { 318 if !u.isOpen.CompareAndSwap(true, false) { 319 return nil 320 } 321 322 windows.PostQueuedCompletionStatus(u.rx.iocp, 0, 0, nil) 323 windows.PostQueuedCompletionStatus(u.tx.iocp, 0, 0, nil) 324 325 u.rx.CloseAndZero() 326 u.tx.CloseAndZero() 327 if u.sock != 0 { 328 windows.CloseHandle(u.sock) 329 } 330 return nil 331 } 332 333 func (ring *ringBuffer) Push() *ringPacket { 334 for ring.isFull { 335 panic("ring is full") 336 } 337 ret := (*ringPacket)(unsafe.Pointer(ring.packets + (uintptr(ring.tail%packetsPerRing) * unsafe.Sizeof(ringPacket{})))) 338 ring.tail += 1 339 if ring.tail%packetsPerRing == ring.head%packetsPerRing { 340 ring.isFull = true 341 } 342 return ret 343 } 344 345 func (ring *ringBuffer) Return(count uint32) { 346 if ring.head%packetsPerRing == ring.tail%packetsPerRing && !ring.isFull { 347 return 348 } 349 ring.head += count 350 ring.isFull = false 351 } 352 353 func (ring *ringBuffer) CloseAndZero() { 354 if ring.cq != 0 { 355 winrio.CloseCompletionQueue(ring.cq) 356 ring.cq = 0 357 } 358 359 if ring.iocp != 0 { 360 windows.CloseHandle(ring.iocp) 361 ring.iocp = 0 362 } 363 364 if ring.id != 0 { 365 winrio.DeregisterBuffer(ring.id) 366 ring.id = 0 367 } 368 369 if ring.packets != 0 { 370 windows.VirtualFree(ring.packets, 0, windows.MEM_RELEASE) 371 ring.packets = 0 372 } 373 374 ring.head = 0 375 ring.tail = 0 376 ring.isFull = false 377 } 378 379 func (ring *ringBuffer) Open() error { 380 var err error 381 packetsLen := unsafe.Sizeof(ringPacket{}) * packetsPerRing 382 ring.packets, err = windows.VirtualAlloc(0, packetsLen, 
windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE) 383 if err != nil { 384 return err 385 } 386 387 ring.id, err = winrio.RegisterPointer(unsafe.Pointer(ring.packets), uint32(packetsLen)) 388 if err != nil { 389 return err 390 } 391 392 ring.iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0) 393 if err != nil { 394 return err 395 } 396 397 ring.cq, err = winrio.CreateIOCPCompletionQueue(packetsPerRing, ring.iocp, 0, &ring.overlapped) 398 if err != nil { 399 return err 400 } 401 402 return nil 403 }