github.com/sagernet/wireguard-go@v0.0.0-20231215174105-89dec3b2f3e8/conn/bind_std.go

/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package conn

import (
	"errors"
	"fmt"
	"net"
	"net/netip"
	"runtime"
	"strconv"
	"sync"
	"syscall"

	"github.com/sagernet/sing/common"
	M "github.com/sagernet/sing/common/metadata"
	"golang.org/x/net/ipv4"
	"golang.org/x/net/ipv6"
)

var _ Bind = (*StdNetBind)(nil)

// StdNetBind implements Bind for all platforms. While Windows has its own Bind
// (see bind_windows.go), it may fall back to StdNetBind.
// TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
// methods for sending and receiving multiple datagrams per-syscall. See the
// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
type StdNetBind struct {
	listener            Listener
	reservedForEndpoint map[netip.AddrPort][3]uint8

	mu            sync.Mutex // protects all fields except as specified
	ipv4          *net.UDPConn
	ipv6          *net.UDPConn
	ipv4PC        *ipv4.PacketConn // will be nil on non-Linux
	ipv6PC        *ipv6.PacketConn // will be nil on non-Linux
	ipv4TxOffload bool
	ipv4RxOffload bool
	ipv6TxOffload bool
	ipv6RxOffload bool

	// these two fields are not guarded by mu
	udpAddrPool sync.Pool
	msgsPool    sync.Pool

	blackhole4 bool
	blackhole6 bool
}

func NewStdNetBind(listener Listener) Bind {
	return &StdNetBind{
		listener: listener,
		// Initialize the map so SetReservedForEndpoint does not write to a
		// nil map.
		reservedForEndpoint: make(map[netip.AddrPort][3]uint8),

		udpAddrPool: sync.Pool{
			New: func() any {
				return &net.UDPAddr{
					IP: make([]byte, 16),
				}
			},
		},

		msgsPool: sync.Pool{
			New: func() any {
				// ipv6.Message and ipv4.Message are interchangeable as they are
				// both aliases for x/net/internal/socket.Message.
				msgs := make([]ipv6.Message, IdealBatchSize)
				for i := range msgs {
					msgs[i].Buffers = make(net.Buffers, 1)
					msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize)
				}
				return &msgs
			},
		},
	}
}

type StdNetEndpoint struct {
	// AddrPort is the endpoint destination.
	netip.AddrPort
	// src is the current sticky source address and interface index, if
	// supported. Typically this is a PKTINFO structure from/for control
	// messages, see unix.PKTINFO for an example.
	src []byte
}

var (
	_ Bind     = (*StdNetBind)(nil)
	_ Endpoint = &StdNetEndpoint{}
)

func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
	e, err := netip.ParseAddrPort(s)
	if err != nil {
		return nil, err
	}
	return &StdNetEndpoint{
		AddrPort: e,
	}, nil
}

func (e *StdNetEndpoint) ClearSrc() {
	if e.src != nil {
		// Truncate src, no need to reallocate.
		e.src = e.src[:0]
	}
}

func (e *StdNetEndpoint) DstIP() netip.Addr {
	return e.AddrPort.Addr()
}

// See control_default, control_linux, etc. for implementations of SrcIP and
// SrcIfidx.

func (e *StdNetEndpoint) DstToBytes() []byte {
	b, _ := e.AddrPort.MarshalBinary()
	return b
}

func (e *StdNetEndpoint) DstToString() string {
	return e.AddrPort.String()
}
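
// exampleEndpointRoundTrip is an illustrative sketch, not referenced by the
// package: it parses an "ip:port" string with ParseEndpoint and reads the
// destination back through the Endpoint accessors above. The address is a
// documentation value, not a real peer.
func exampleEndpointRoundTrip() error {
	var bind StdNetBind
	ep, err := bind.ParseEndpoint("203.0.113.5:51820")
	if err != nil {
		return err
	}
	// DstToString re-renders the parsed netip.AddrPort unchanged.
	fmt.Printf("dst %s (addr %s)\n", ep.DstToString(), ep.DstIP())
	return nil
}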

func listenNet(listener Listener, network string, port int) (*net.UDPConn, int, error) {
	conn, err := listener.ListenPacketCompat(network, ":"+strconv.Itoa(port))
	if err != nil {
		return nil, 0, err
	}

	// Retrieve port.
	laddr := conn.LocalAddr()
	uaddr, err := net.ResolveUDPAddr(
		laddr.Network(),
		laddr.String(),
	)
	if err != nil {
		return nil, 0, err
	}
	return conn.(*net.UDPConn), uaddr.Port, nil
}

func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	var err error
	var tries int

	if s.ipv4 != nil || s.ipv6 != nil {
		return nil, 0, ErrBindAlreadyOpen
	}

	// Attempt to open ipv4 and ipv6 listeners on the same port.
	// If uport is 0, we can retry on failure.
again:
	port := int(uport)
	var v4conn, v6conn *net.UDPConn
	var v4pc *ipv4.PacketConn
	var v6pc *ipv6.PacketConn

	v4conn, port, err = listenNet(s.listener, "udp4", port)
	if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
		return nil, 0, err
	}

	// Listen on the same port as we're using for ipv4.
	v6conn, port, err = listenNet(s.listener, "udp6", port)
	if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
		// v4conn is nil when IPv4 is unsupported, so guard the Close.
		if v4conn != nil {
			v4conn.Close()
		}
		tries++
		goto again
	}
	if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
		if v4conn != nil {
			v4conn.Close()
		}
		return nil, 0, err
	}
	var fns []ReceiveFunc
	if v4conn != nil {
		s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn)
		if runtime.GOOS == "linux" || runtime.GOOS == "android" {
			v4pc = ipv4.NewPacketConn(v4conn)
			s.ipv4PC = v4pc
		}
		fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload))
		s.ipv4 = v4conn
	}
	if v6conn != nil {
		s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn)
		if runtime.GOOS == "linux" || runtime.GOOS == "android" {
			v6pc = ipv6.NewPacketConn(v6conn)
			s.ipv6PC = v6pc
		}
		fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload))
		s.ipv6 = v6conn
	}
	if len(fns) == 0 {
		return nil, 0, syscall.EAFNOSUPPORT
	}

	return fns, uint16(port), nil
}

func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) {
	for i := range *msgs {
		(*msgs)[i].OOB = (*msgs)[i].OOB[:0]
		(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
	}
	s.msgsPool.Put(msgs)
}

func (s *StdNetBind) getMessages() *[]ipv6.Message {
	return s.msgsPool.Get().(*[]ipv6.Message)
}
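
// exampleOpen is an illustrative sketch, not referenced by the package, of the
// Open contract above: port 0 requests a kernel-assigned port shared by the
// v4 and v6 sockets, and each returned ReceiveFunc is driven with
// BatchSize-sized slices. The Listener is assumed to be supplied by the
// caller, and buffer sizing is simplified to a single MTU.
func exampleOpen(listener Listener) error {
	bind := NewStdNetBind(listener)
	fns, port, err := bind.Open(0)
	if err != nil {
		return err
	}
	defer bind.Close()
	fmt.Printf("listening on port %d with %d receive function(s)\n", port, len(fns))

	batch := bind.BatchSize()
	bufs := make([][]byte, batch)
	for i := range bufs {
		bufs[i] = make([]byte, 1500)
	}
	sizes := make([]int, batch)
	eps := make([]Endpoint, batch)
	// A real caller runs one loop per ReceiveFunc on its own goroutine; a
	// single blocking read is shown here.
	n, err := fns[0](bufs, sizes, eps)
	if err != nil {
		return err
	}
	fmt.Printf("received %d datagram(s)\n", n)
	return nil
}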

// If compilation fails here these are no longer the same underlying type.
var _ ipv6.Message = ipv4.Message{}

type batchReader interface {
	ReadBatch([]ipv6.Message, int) (int, error)
}

type batchWriter interface {
	WriteBatch([]ipv6.Message, int) (int, error)
}

func (s *StdNetBind) receiveIP(
	br batchReader,
	conn *net.UDPConn,
	rxOffload bool,
	bufs [][]byte,
	sizes []int,
	eps []Endpoint,
) (n int, err error) {
	msgs := s.getMessages()
	for i := range bufs {
		(*msgs)[i].Buffers[0] = bufs[i]
		(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
	}
	defer s.putMessages(msgs)
	var numMsgs int
	if runtime.GOOS == "linux" || runtime.GOOS == "android" {
		if rxOffload {
			// Read into the tail of msgs so the head is free for
			// splitCoalescedMessages to write the split-out datagrams into.
			readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams)
			numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0)
			if err != nil {
				return 0, err
			}
			numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize)
			if err != nil {
				return 0, err
			}
		} else {
			numMsgs, err = br.ReadBatch(*msgs, 0)
			if err != nil {
				return 0, err
			}
		}
	} else {
		msg := &(*msgs)[0]
		msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
		if err != nil {
			return 0, err
		}
		numMsgs = 1
	}
	for i := 0; i < numMsgs; i++ {
		msg := &(*msgs)[i]
		sizes[i] = msg.N
		if sizes[i] == 0 {
			continue
		}
		if msg.N > 3 {
			// Zero the three reserved bytes of the message header, which the
			// remote may have filled in (see SetReservedForEndpoint).
			common.ClearArray(bufs[i][1:4])
		}
		ep := &StdNetEndpoint{AddrPort: M.AddrPortFromNet(msg.Addr)} // TODO: remove allocation
		getSrcFromControl(msg.OOB[:msg.NN], ep)
		eps[i] = ep
	}
	return numMsgs, nil
}

func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
	return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
		return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
	}
}

func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
	return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
		return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
	}
}
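
// exampleSplitGRO is an illustrative sketch, not referenced by the package, of
// the receive-offload arithmetic used by receiveIP: a single 3000-byte read
// carrying a GSO size of 1400 splits into datagrams of 1400, 1400, and 200
// bytes, so the call returns 3. The constant getGSOFunc stands in for the
// real control-message parser.
func exampleSplitGRO() (int, error) {
	msgs := make([]ipv6.Message, 4)
	for i := range msgs {
		msgs[i].Buffers = net.Buffers{make([]byte, 3000)}
	}
	// Pretend the kernel coalesced three datagrams into the message at the
	// tail of the slice, as receiveIP arranges with its readAt offset.
	msgs[3].N = 3000
	constGSO := func(control []byte) (int, error) { return 1400, nil }
	return splitCoalescedMessages(msgs, 3, constGSO)
}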

// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
// rename the IdealBatchSize constant to BatchSize.
func (s *StdNetBind) BatchSize() int {
	if runtime.GOOS == "linux" || runtime.GOOS == "android" {
		return IdealBatchSize
	}
	return 1
}

func (s *StdNetBind) Close() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	var err1, err2 error
	if s.ipv4 != nil {
		err1 = s.ipv4.Close()
		s.ipv4 = nil
		s.ipv4PC = nil
	}
	if s.ipv6 != nil {
		err2 = s.ipv6.Close()
		s.ipv6 = nil
		s.ipv6PC = nil
	}
	s.blackhole4 = false
	s.blackhole6 = false
	s.ipv4TxOffload = false
	s.ipv4RxOffload = false
	s.ipv6TxOffload = false
	s.ipv6RxOffload = false
	if err1 != nil {
		return err1
	}
	return err2
}

type ErrUDPGSODisabled struct {
	onLaddr  string
	RetryErr error
}

func (e ErrUDPGSODisabled) Error() string {
	return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.onLaddr)
}

func (e ErrUDPGSODisabled) Unwrap() error {
	return e.RetryErr
}

func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
	s.mu.Lock()
	blackhole := s.blackhole4
	conn := s.ipv4
	offload := s.ipv4TxOffload
	br := batchWriter(s.ipv4PC)
	is6 := false
	if endpoint.DstIP().Is6() {
		blackhole = s.blackhole6
		conn = s.ipv6
		br = s.ipv6PC
		is6 = true
		offload = s.ipv6TxOffload
	}
	s.mu.Unlock()

	if blackhole {
		return nil
	}
	if conn == nil {
		return syscall.EAFNOSUPPORT
	}

	msgs := s.getMessages()
	defer s.putMessages(msgs)
	ua := s.udpAddrPool.Get().(*net.UDPAddr)
	defer s.udpAddrPool.Put(ua)
	if is6 {
		as16 := endpoint.DstIP().As16()
		copy(ua.IP, as16[:])
		ua.IP = ua.IP[:16]
	} else {
		as4 := endpoint.DstIP().As4()
		copy(ua.IP, as4[:])
		ua.IP = ua.IP[:4]
	}
	ua.Port = int(endpoint.(*StdNetEndpoint).Port())
	var (
		retried bool
		err     error
	)
retry:
	if offload {
		n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize)
		err = s.send(conn, br, (*msgs)[:n])
		if err != nil && offload && errShouldDisableUDPGSO(err) {
			offload = false
			s.mu.Lock()
			if is6 {
				s.ipv6TxOffload = false
			} else {
				s.ipv4TxOffload = false
			}
			s.mu.Unlock()
			retried = true
			goto retry
		}
	} else {
		for i := range bufs {
			bufI := bufs[i]
			if len(bufI) > 3 {
				// Stamp the endpoint's registered reserved bytes into header
				// bytes 1-3, if any.
				reserved, loaded := s.reservedForEndpoint[endpoint.(*StdNetEndpoint).AddrPort]
				if loaded {
					copy(bufI[1:4], reserved[:])
				}
			}

			(*msgs)[i].Addr = ua
			(*msgs)[i].Buffers[0] = bufs[i]
			setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint))
		}
		err = s.send(conn, br, (*msgs)[:len(bufs)])
	}
	if retried {
		return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err}
	}
	return err
}

func (s *StdNetBind) SetReservedForEndpoint(destination netip.AddrPort, reserved [3]byte) {
	s.reservedForEndpoint[destination] = reserved
}
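
// exampleReservedBytes is an illustrative sketch, not referenced by the
// package, of the reserved-bytes hook above: once a value is registered for a
// destination, Send on the non-offload path stamps it into header bytes 1-3,
// and receiveIP zeroes those bytes again on the way in. The endpoint address
// and reserved value are placeholders.
func exampleReservedBytes(bind Bind) error {
	ep, err := bind.ParseEndpoint("203.0.113.5:51820")
	if err != nil {
		return err
	}
	if std, ok := bind.(*StdNetBind); ok {
		std.SetReservedForEndpoint(ep.(*StdNetEndpoint).AddrPort, [3]byte{0xAA, 0xBB, 0xCC})
	}
	packet := make([]byte, 148) // a handshake-initiation-sized message
	packet[0] = 1               // message type; bytes 1-3 are the reserved field
	return bind.Send([][]byte{packet}, ep)
}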

func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error {
	var (
		n     int
		err   error
		start int
	)
	if runtime.GOOS == "linux" || runtime.GOOS == "android" {
		for {
			n, err = pc.WriteBatch(msgs[start:], 0)
			if err != nil || n == len(msgs[start:]) {
				break
			}
			start += n
		}
	} else {
		for _, msg := range msgs {
			_, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr))
			if err != nil {
				break
			}
		}
	}
	return err
}

const (
	// Exceeding these values results in EMSGSIZE. They account for layer 3 and
	// layer 4 headers. IPv6 does not need to account for its own header, as
	// the payload length field excludes it.
	maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8
	maxIPv6PayloadLen = 1<<16 - 1 - 8

	// This is a hard limit imposed by the kernel.
	udpSegmentMaxDatagrams = 64
)

type setGSOFunc func(control *[]byte, gsoSize uint16)

func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int {
	var (
		base     = -1 // index of msg we are currently coalescing into
		gsoSize  int  // segmentation size of msgs[base]
		dgramCnt int  // number of dgrams coalesced into msgs[base]
		endBatch bool // tracking flag to start a new batch on next iteration of bufs
	)
	maxPayloadLen := maxIPv4PayloadLen
	if ep.DstIP().Is6() {
		maxPayloadLen = maxIPv6PayloadLen
	}
	for i, buf := range bufs {
		if i > 0 {
			msgLen := len(buf)
			baseLenBefore := len(msgs[base].Buffers[0])
			freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
			if msgLen+baseLenBefore <= maxPayloadLen &&
				msgLen <= gsoSize &&
				msgLen <= freeBaseCap &&
				dgramCnt < udpSegmentMaxDatagrams &&
				!endBatch {
				msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...)
				if i == len(bufs)-1 {
					setGSO(&msgs[base].OOB, uint16(gsoSize))
				}
				dgramCnt++
				if msgLen < gsoSize {
					// A smaller than gsoSize packet on the tail is legal, but
					// it must end the batch.
					endBatch = true
				}
				continue
			}
		}
		if dgramCnt > 1 {
			setGSO(&msgs[base].OOB, uint16(gsoSize))
		}
		// Reset prior to incrementing base since we are preparing to start a
		// new potential batch.
		endBatch = false
		base++
		gsoSize = len(buf)
		setSrcControl(&msgs[base].OOB, ep)
		msgs[base].Buffers[0] = buf
		msgs[base].Addr = addr
		dgramCnt = 1
	}
	return base + 1
}

type getGSOFunc func(control []byte) (int, error)

func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) {
	for i := firstMsgAt; i < len(msgs); i++ {
		msg := &msgs[i]
		if msg.N == 0 {
			return n, err
		}
		var (
			gsoSize    int
			start      int
			end        = msg.N
			numToSplit = 1
		)
		gsoSize, err = getGSO(msg.OOB[:msg.NN])
		if err != nil {
			return n, err
		}
		if gsoSize > 0 {
			numToSplit = (msg.N + gsoSize - 1) / gsoSize
			end = gsoSize
		}
		for j := 0; j < numToSplit; j++ {
			if n > i {
				return n, errors.New("splitting coalesced packet resulted in overflow")
			}
			copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
			msgs[n].N = copied
			msgs[n].Addr = msg.Addr
			start = end
			end += gsoSize
			if end > msg.N {
				end = msg.N
			}
			n++
		}
		if i != n-1 {
			// It is legal for bytes to move within msg.Buffers[0] as a result
			// of splitting, so we only zero the source msg len when it is not
			// the destination of the last split operation above.
			msg.N = 0
		}
	}
	return n, nil
}
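
// exampleCoalesce is an illustrative sketch, not referenced by the package, of
// the transmit-offload arithmetic above: three equal-sized datagrams to one
// endpoint collapse into a single message whose GSO size is the segment
// length, so coalesceMessages returns 1 and the kernel performs one send
// instead of three. The no-op setGSOFunc stands in for the real
// control-message writer, and the endpoint and address are assumed to come
// from the caller.
func exampleCoalesce(ep *StdNetEndpoint, addr *net.UDPAddr) int {
	const segment = 1400
	bufs := make([][]byte, 3)
	for i := range bufs {
		// Spare capacity lets later segments be appended onto the base buffer.
		bufs[i] = make([]byte, segment, 3*segment)
	}
	msgs := make([]ipv6.Message, len(bufs))
	for i := range msgs {
		msgs[i].Buffers = make(net.Buffers, 1)
	}
	noopGSO := func(control *[]byte, gsoSize uint16) {}
	return coalesceMessages(addr, ep, bufs, msgs, noopGSO)
}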