github.com/noisysockets/noisysockets@v0.21.2-0.20240515114641-7f467e651c90/internal/conn/bind_std.go

// SPDX-License-Identifier: MPL-2.0
/*
 * Copyright (C) 2024 The Noisy Sockets Authors.
 *
 * This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/.
 *
 * Portions of this file are based on code originally from wireguard-go,
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy of
 * this software and associated documentation files (the "Software"), to deal in
 * the Software without restriction, including without limitation the rights to
 * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
 * of the Software, and to permit persons to whom the Software is furnished to do
 * so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

package conn

import (
	"context"
	"errors"
	"fmt"
	"net"
	"net/netip"
	"runtime"
	"strconv"
	"sync"
	"syscall"

	"golang.org/x/net/ipv4"
	"golang.org/x/net/ipv6"
)

var (
	_ Bind = (*StdNetBind)(nil)
)

// StdNetBind implements Bind for all platforms. While Windows has its own Bind
// (see bind_windows.go), it may fall back to StdNetBind.
// TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
// methods for sending and receiving multiple datagrams per-syscall. See the
// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
type StdNetBind struct {
	mu            sync.Mutex // protects all fields except as specified
	ipv4          *net.UDPConn
	ipv6          *net.UDPConn
	ipv4PC        *ipv4.PacketConn // will be nil on non-Linux
	ipv6PC        *ipv6.PacketConn // will be nil on non-Linux
	ipv4TxOffload bool
	ipv4RxOffload bool
	ipv6TxOffload bool
	ipv6RxOffload bool

	// these two fields are not guarded by mu
	udpAddrPool sync.Pool
	msgsPool    sync.Pool

	blackhole4 bool
	blackhole6 bool
}

func NewStdNetBind() Bind {
	return &StdNetBind{
		udpAddrPool: sync.Pool{
			New: func() any {
				return &net.UDPAddr{
					IP: make([]byte, 16),
				}
			},
		},

		msgsPool: sync.Pool{
			New: func() any {
				// ipv6.Message and ipv4.Message are interchangeable as they are
				// both aliases for x/net/internal/socket.Message.
				msgs := make([]ipv6.Message, IdealBatchSize)
				for i := range msgs {
					msgs[i].Buffers = make(net.Buffers, 1)
					msgs[i].OOB = make([]byte, 0, gsoControlSize)
				}
				return &msgs
			},
		},
	}
}
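// For illustration, a minimal sketch of the receive side of this Bind. It
// assumes (as in wireguard-go) that the package's Bind interface exposes Open,
// BatchSize and Close as implemented below; the buffer size and error handling
// are placeholders, not part of this file:
//
//	bind := NewStdNetBind()
//	fns, port, err := bind.Open(0) // 0 asks the kernel for an ephemeral port
//	if err != nil {
//		return err
//	}
//	_ = port
//	for _, fn := range fns {
//		go func(recv ReceiveFunc) {
//			batch := bind.BatchSize()
//			bufs := make([][]byte, batch)
//			sizes := make([]int, batch)
//			eps := make([]Endpoint, batch)
//			for i := range bufs {
//				bufs[i] = make([]byte, 65535) // placeholder datagram buffer size
//			}
//			for {
//				n, err := recv(bufs, sizes, eps)
//				if err != nil {
//					return
//				}
//				for i := 0; i < n; i++ {
//					_ = bufs[i][:sizes[i]] // datagram i, received from eps[i]
//				}
//			}
//		}(fn)
//	}
//	defer bind.Close()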
type StdNetEndpoint struct {
	// AddrPort is the endpoint destination.
	netip.AddrPort
}

var (
	_ Bind     = (*StdNetBind)(nil)
	_ Endpoint = &StdNetEndpoint{}
)

func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
	e, err := netip.ParseAddrPort(s)
	if err != nil {
		return nil, err
	}
	return &StdNetEndpoint{
		AddrPort: e,
	}, nil
}

func (e *StdNetEndpoint) DstIP() netip.Addr {
	return e.AddrPort.Addr()
}

// See control_default,linux, etc for implementations of SrcIP and SrcIfidx.

func (e *StdNetEndpoint) DstToBytes() []byte {
	b, _ := e.AddrPort.MarshalBinary()
	return b
}

func (e *StdNetEndpoint) DstToString() string {
	return e.AddrPort.String()
}
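// Continuing the sketch above, the send side is the mirror image. It assumes,
// as in wireguard-go, that ParseEndpoint (above) and Send (further down) are
// part of the Bind interface; the address, port and payloads are placeholder
// values, and bufs is assumed to hold at most BatchSize() packets, matching
// the length of the pooled message slice:
//
//	ep, err := bind.ParseEndpoint("192.0.2.10:51820")
//	if err != nil {
//		return err
//	}
//	bufs := [][]byte{
//		[]byte("first datagram"),
//		[]byte("second datagram"),
//	}
//	// All packets in one Send call go to the same endpoint and, where the
//	// kernel and NIC support it, are coalesced into fewer syscalls via UDP GSO.
//	if err := bind.Send(bufs, ep); err != nil {
//		return err
//	}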
func listenNet(network string, port int) (*net.UDPConn, int, error) {
	conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
	if err != nil {
		return nil, 0, err
	}

	// Retrieve port.
	laddr := conn.LocalAddr()
	uaddr, err := net.ResolveUDPAddr(
		laddr.Network(),
		laddr.String(),
	)
	if err != nil {
		return nil, 0, err
	}
	return conn.(*net.UDPConn), uaddr.Port, nil
}

func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	var err error
	var tries int

	if s.ipv4 != nil || s.ipv6 != nil {
		return nil, 0, ErrBindAlreadyOpen
	}

	// Attempt to open ipv4 and ipv6 listeners on the same port.
	// If uport is 0, we can retry on failure.
again:
	port := int(uport)
	var v4conn, v6conn *net.UDPConn
	var v4pc *ipv4.PacketConn
	var v6pc *ipv6.PacketConn

	v4conn, port, err = listenNet("udp4", port)
	if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
		return nil, 0, err
	}

	// Listen on the same port as we're using for ipv4.
	v6conn, port, err = listenNet("udp6", port)
	if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
		v4conn.Close()
		tries++
		goto again
	}
	if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
		v4conn.Close()
		return nil, 0, err
	}
	var fns []ReceiveFunc
	if v4conn != nil {
		s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn)
		if runtime.GOOS == "linux" {
			v4pc = ipv4.NewPacketConn(v4conn)
			s.ipv4PC = v4pc
		}
		fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload))
		s.ipv4 = v4conn
	}
	if v6conn != nil {
		s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn)
		if runtime.GOOS == "linux" {
			v6pc = ipv6.NewPacketConn(v6conn)
			s.ipv6PC = v6pc
		}
		fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload))
		s.ipv6 = v6conn
	}
	if len(fns) == 0 {
		return nil, 0, syscall.EAFNOSUPPORT
	}

	return fns, uint16(port), nil
}

func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) {
	for i := range *msgs {
		(*msgs)[i].OOB = (*msgs)[i].OOB[:0]
		(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
	}
	s.msgsPool.Put(msgs)
}

func (s *StdNetBind) getMessages() *[]ipv6.Message {
	return s.msgsPool.Get().(*[]ipv6.Message)
}

var (
	// If compilation fails here these are no longer the same underlying type.
	_ ipv6.Message = ipv4.Message{}
)

type batchReader interface {
	ReadBatch([]ipv6.Message, int) (int, error)
}

type batchWriter interface {
	WriteBatch([]ipv6.Message, int) (int, error)
}

func (s *StdNetBind) receiveIP(
	br batchReader,
	conn *net.UDPConn,
	rxOffload bool,
	bufs [][]byte,
	sizes []int,
	eps []Endpoint,
) (n int, err error) {
	msgs := s.getMessages()
	for i := range bufs {
		(*msgs)[i].Buffers[0] = bufs[i]
		(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
	}
	defer s.putMessages(msgs)
	var numMsgs int
	if runtime.GOOS == "linux" {
		if rxOffload {
			readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams)
			_, err = br.ReadBatch((*msgs)[readAt:], 0)
			if err != nil {
				return 0, err
			}
			numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize)
			if err != nil {
				return 0, err
			}
		} else {
			numMsgs, err = br.ReadBatch(*msgs, 0)
			if err != nil {
				return 0, err
			}
		}
	} else {
		msg := &(*msgs)[0]
		msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
		if err != nil {
			return 0, err
		}
		numMsgs = 1
	}
	for i := 0; i < numMsgs; i++ {
		msg := &(*msgs)[i]
		sizes[i] = msg.N
		if sizes[i] == 0 {
			continue
		}
		addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
		ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
		eps[i] = ep
	}
	return numMsgs, nil
}
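// To make the rx-offload arithmetic in receiveIP concrete: assuming
// IdealBatchSize is 128 (its value in upstream wireguard-go; it is defined
// elsewhere in this package) and with udpSegmentMaxDatagrams = 64, readAt is
// 128 - 128/64 = 126, so ReadBatch fills at most the last two message slots.
// Each of those slots may hold a GRO-coalesced payload of up to 64 segments,
// which splitCoalescedMessages then fans out into msgs starting at index 0 —
// at most 2*64 = 128 individual datagrams, exactly the batch handed back to
// the caller.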
func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
	return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
		return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
	}
}

func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
	return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
		return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
	}
}

func (s *StdNetBind) BatchSize() int {
	if runtime.GOOS == "linux" {
		return IdealBatchSize
	}
	return 1
}

func (s *StdNetBind) Close() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	var err1, err2 error
	if s.ipv4 != nil {
		err1 = s.ipv4.Close()
		s.ipv4 = nil
		s.ipv4PC = nil
	}
	if s.ipv6 != nil {
		err2 = s.ipv6.Close()
		s.ipv6 = nil
		s.ipv6PC = nil
	}
	s.blackhole4 = false
	s.blackhole6 = false
	s.ipv4TxOffload = false
	s.ipv4RxOffload = false
	s.ipv6TxOffload = false
	s.ipv6RxOffload = false
	if err1 != nil {
		return err1
	}
	return err2
}

type ErrUDPGSODisabled struct {
	onLaddr  string
	RetryErr error
}

func (e ErrUDPGSODisabled) Error() string {
	return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.onLaddr)
}

func (e ErrUDPGSODisabled) Unwrap() error {
	return e.RetryErr
}

func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
	s.mu.Lock()
	blackhole := s.blackhole4
	conn := s.ipv4
	offload := s.ipv4TxOffload
	br := batchWriter(s.ipv4PC)
	is6 := false
	if endpoint.DstIP().Is6() {
		blackhole = s.blackhole6
		conn = s.ipv6
		br = s.ipv6PC
		is6 = true
		offload = s.ipv6TxOffload
	}
	s.mu.Unlock()

	if blackhole {
		return nil
	}
	if conn == nil {
		return syscall.EAFNOSUPPORT
	}

	msgs := s.getMessages()
	defer s.putMessages(msgs)
	ua := s.udpAddrPool.Get().(*net.UDPAddr)
	defer s.udpAddrPool.Put(ua)
	if is6 {
		as16 := endpoint.DstIP().As16()
		copy(ua.IP, as16[:])
		ua.IP = ua.IP[:16]
	} else {
		as4 := endpoint.DstIP().As4()
		copy(ua.IP, as4[:])
		ua.IP = ua.IP[:4]
	}
	ua.Port = int(endpoint.(*StdNetEndpoint).Port())
	var (
		retried bool
		err     error
	)
retry:
	if offload {
		n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize)
		err = s.send(conn, br, (*msgs)[:n])
		if err != nil && offload && errShouldDisableUDPGSO(err) {
			offload = false
			s.mu.Lock()
			if is6 {
				s.ipv6TxOffload = false
			} else {
				s.ipv4TxOffload = false
			}
			s.mu.Unlock()
			retried = true
			goto retry
		}
	} else {
		for i := range bufs {
			(*msgs)[i].Addr = ua
			(*msgs)[i].Buffers[0] = bufs[i]
		}
		err = s.send(conn, br, (*msgs)[:len(bufs)])
	}
	if retried {
		return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err}
	}
	return err
}

func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error {
	var (
		n     int
		err   error
		start int
	)
	if runtime.GOOS == "linux" {
		for {
			n, err = pc.WriteBatch(msgs[start:], 0)
			if err != nil || n == len(msgs[start:]) {
				break
			}
			start += n
		}
	} else {
		for _, msg := range msgs {
			_, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr))
			if err != nil {
				break
			}
		}
	}
	return err
}

const (
	// Exceeding these values results in EMSGSIZE. They account for layer3 and
	// layer4 headers. IPv6 does not need to account for itself as the payload
	// length field is self excluding.
	maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8
	maxIPv6PayloadLen = 1<<16 - 1 - 8

	// This is a hard limit imposed by the kernel.
	udpSegmentMaxDatagrams = 64
)
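// A quick worked example of the limits above and of coalesceMessages below:
// maxIPv4PayloadLen = 65535 - 20 (IPv4 header) - 8 (UDP header) = 65507 bytes,
// and maxIPv6PayloadLen = 65535 - 8 = 65527 bytes, since the IPv6 payload
// length field excludes the IPv6 header itself. If Send is given four
// 1200-byte packets followed by one 700-byte packet for the same endpoint, and
// the first buffer has enough spare capacity, coalesceMessages appends all
// five into a single message and records a segment size of 1200 via the GSO
// control message (setGSOSize); the kernel then splits the 5500-byte payload
// back into datagrams of 1200, 1200, 1200, 1200 and 700 on transmit. A packet
// shorter than the segment size is only legal at the tail, so it closes the
// batch (the endBatch flag below).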
type setGSOFunc func(control *[]byte, gsoSize uint16)

func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int {
	var (
		base     = -1 // index of msg we are currently coalescing into
		gsoSize  int  // segmentation size of msgs[base]
		dgramCnt int  // number of dgrams coalesced into msgs[base]
		endBatch bool // tracking flag to start a new batch on next iteration of bufs
	)
	maxPayloadLen := maxIPv4PayloadLen
	if ep.DstIP().Is6() {
		maxPayloadLen = maxIPv6PayloadLen
	}
	for i, buf := range bufs {
		if i > 0 {
			msgLen := len(buf)
			baseLenBefore := len(msgs[base].Buffers[0])
			freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
			if msgLen+baseLenBefore <= maxPayloadLen &&
				msgLen <= gsoSize &&
				msgLen <= freeBaseCap &&
				dgramCnt < udpSegmentMaxDatagrams &&
				!endBatch {
				msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...)
				if i == len(bufs)-1 {
					setGSO(&msgs[base].OOB, uint16(gsoSize))
				}
				dgramCnt++
				if msgLen < gsoSize {
					// A smaller than gsoSize packet on the tail is legal, but
					// it must end the batch.
					endBatch = true
				}
				continue
			}
		}
		if dgramCnt > 1 {
			setGSO(&msgs[base].OOB, uint16(gsoSize))
		}
		// Reset prior to incrementing base since we are preparing to start a
		// new potential batch.
		endBatch = false
		base++
		gsoSize = len(buf)
		msgs[base].Buffers[0] = buf
		msgs[base].Addr = addr
		dgramCnt = 1
	}
	return base + 1
}

type getGSOFunc func(control []byte) (int, error)

func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) {
	for i := firstMsgAt; i < len(msgs); i++ {
		msg := &msgs[i]
		if msg.N == 0 {
			return n, err
		}
		var (
			gsoSize    int
			start      int
			end        = msg.N
			numToSplit = 1
		)
		gsoSize, err = getGSO(msg.OOB[:msg.NN])
		if err != nil {
			return n, err
		}
		if gsoSize > 0 {
			numToSplit = (msg.N + gsoSize - 1) / gsoSize
			end = gsoSize
		}
		for j := 0; j < numToSplit; j++ {
			if n > i {
				return n, errors.New("splitting coalesced packet resulted in overflow")
			}
			copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
			msgs[n].N = copied
			msgs[n].Addr = msg.Addr
			start = end
			end += gsoSize
			if end > msg.N {
				end = msg.N
			}
			n++
		}
		if i != n-1 {
			// It is legal for bytes to move within msg.Buffers[0] as a result
			// of splitting, so we only zero the source msg len when it is not
			// the destination of the last split operation above.
			msg.N = 0
		}
	}
	return n, nil
}
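// A worked example of the receive-side split above: suppose the kernel
// delivers a single GRO-coalesced message with msg.N = 3000 and a GSO size of
// 1200 in its control data. Then numToSplit = (3000 + 1200 - 1) / 1200 = 3,
// and the payload is copied out as three datagrams of 1200, 1200 and 600
// bytes, each inheriting the sender's address. A message whose control data
// reports gsoSize == 0 was not coalesced and passes through as a single
// datagram.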