github.com/amnezia-vpn/amnezia-wg@v0.1.8/conn/bind_std.go

/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package conn

import (
	"context"
	"errors"
	"fmt"
	"net"
	"net/netip"
	"runtime"
	"strconv"
	"sync"
	"syscall"

	"golang.org/x/net/ipv4"
	"golang.org/x/net/ipv6"
)

var (
	_ Bind = (*StdNetBind)(nil)
)

// StdNetBind implements Bind for all platforms. While Windows has its own Bind
// (see bind_windows.go), it may fall back to StdNetBind.
// TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
// methods for sending and receiving multiple datagrams per-syscall. See the
// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
type StdNetBind struct {
	mu            sync.Mutex // protects all fields except as specified
	ipv4          *net.UDPConn
	ipv6          *net.UDPConn
	ipv4PC        *ipv4.PacketConn // will be nil on non-Linux
	ipv6PC        *ipv6.PacketConn // will be nil on non-Linux
	ipv4TxOffload bool
	ipv4RxOffload bool
	ipv6TxOffload bool
	ipv6RxOffload bool

	// these two fields are not guarded by mu
	udpAddrPool sync.Pool
	msgsPool    sync.Pool

	blackhole4 bool
	blackhole6 bool
}

func NewStdNetBind() Bind {
	return &StdNetBind{
		udpAddrPool: sync.Pool{
			New: func() any {
				return &net.UDPAddr{
					IP: make([]byte, 16),
				}
			},
		},

		msgsPool: sync.Pool{
			New: func() any {
				// ipv6.Message and ipv4.Message are interchangeable as they are
				// both aliases for x/net/internal/socket.Message.
				msgs := make([]ipv6.Message, IdealBatchSize)
				for i := range msgs {
					msgs[i].Buffers = make(net.Buffers, 1)
					msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize)
				}
				return &msgs
			},
		},
	}
}

type StdNetEndpoint struct {
	// AddrPort is the endpoint destination.
	netip.AddrPort
	// src is the current sticky source address and interface index, if
	// supported. Typically this is a PKTINFO structure from/for control
	// messages, see unix.PKTINFO for an example.
	src []byte
}

var (
	_ Bind     = (*StdNetBind)(nil)
	_ Endpoint = &StdNetEndpoint{}
)

func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
	e, err := netip.ParseAddrPort(s)
	if err != nil {
		return nil, err
	}
	return &StdNetEndpoint{
		AddrPort: e,
	}, nil
}

func (e *StdNetEndpoint) ClearSrc() {
	if e.src != nil {
		// Truncate src, no need to reallocate.
		e.src = e.src[:0]
	}
}

func (e *StdNetEndpoint) DstIP() netip.Addr {
	return e.AddrPort.Addr()
}

// See control_default,linux, etc for implementations of SrcIP and SrcIfidx.

func (e *StdNetEndpoint) DstToBytes() []byte {
	b, _ := e.AddrPort.MarshalBinary()
	return b
}

func (e *StdNetEndpoint) DstToString() string {
	return e.AddrPort.String()
}
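// Illustrative sketch (not part of the original file): constructing an
// endpoint with ParseEndpoint and reading it back through the Endpoint
// accessors above. The address is an arbitrary documentation value
// (192.0.2.0/24 is reserved for examples).
func exampleEndpointRoundTrip() error {
	var b StdNetBind
	ep, err := b.ParseEndpoint("192.0.2.1:51820")
	if err != nil {
		return err
	}
	// DstIP yields the destination address, DstToString the "ip:port" form,
	// and DstToBytes the binary netip.AddrPort encoding.
	fmt.Println(ep.DstIP(), ep.DstToString(), len(ep.DstToBytes()))
	return nil
}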
func listenNet(network string, port int) (*net.UDPConn, int, error) {
	conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
	if err != nil {
		return nil, 0, err
	}

	// Retrieve port.
	laddr := conn.LocalAddr()
	uaddr, err := net.ResolveUDPAddr(
		laddr.Network(),
		laddr.String(),
	)
	if err != nil {
		return nil, 0, err
	}
	return conn.(*net.UDPConn), uaddr.Port, nil
}

func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	var err error
	var tries int

	if s.ipv4 != nil || s.ipv6 != nil {
		return nil, 0, ErrBindAlreadyOpen
	}

	// Attempt to open ipv4 and ipv6 listeners on the same port.
	// If uport is 0, we can retry on failure.
again:
	port := int(uport)
	var v4conn, v6conn *net.UDPConn
	var v4pc *ipv4.PacketConn
	var v6pc *ipv6.PacketConn

	v4conn, port, err = listenNet("udp4", port)
	if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
		return nil, 0, err
	}

	// Listen on the same port as we're using for ipv4.
	v6conn, port, err = listenNet("udp6", port)
	if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
		v4conn.Close()
		tries++
		goto again
	}
	if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
		v4conn.Close()
		return nil, 0, err
	}
	var fns []ReceiveFunc
	if v4conn != nil {
		s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn)
		if runtime.GOOS == "linux" || runtime.GOOS == "android" {
			v4pc = ipv4.NewPacketConn(v4conn)
			s.ipv4PC = v4pc
		}
		fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload))
		s.ipv4 = v4conn
	}
	if v6conn != nil {
		s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn)
		if runtime.GOOS == "linux" || runtime.GOOS == "android" {
			v6pc = ipv6.NewPacketConn(v6conn)
			s.ipv6PC = v6pc
		}
		fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload))
		s.ipv6 = v6conn
	}
	if len(fns) == 0 {
		return nil, 0, syscall.EAFNOSUPPORT
	}

	return fns, uint16(port), nil
}

func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) {
	for i := range *msgs {
		(*msgs)[i].OOB = (*msgs)[i].OOB[:0]
		(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
	}
	s.msgsPool.Put(msgs)
}

func (s *StdNetBind) getMessages() *[]ipv6.Message {
	return s.msgsPool.Get().(*[]ipv6.Message)
}
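// Illustrative sketch (not part of the original file): the borrow/return
// discipline around msgsPool. receiveIP and Send both follow this pattern;
// putMessages truncates OOB and drops per-call state (N, NN, Addr) so that
// no stale control data or lengths leak into the next batch.
func exampleMessagePoolUsage(s *StdNetBind) {
	msgs := s.getMessages()
	defer s.putMessages(msgs)
	// Each pooled message carries exactly one buffer, provided by the caller.
	(*msgs)[0].Buffers[0] = make([]byte, 1500)
}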
var (
	// If compilation fails here these are no longer the same underlying type.
	_ ipv6.Message = ipv4.Message{}
)

type batchReader interface {
	ReadBatch([]ipv6.Message, int) (int, error)
}

type batchWriter interface {
	WriteBatch([]ipv6.Message, int) (int, error)
}

func (s *StdNetBind) receiveIP(
	br batchReader,
	conn *net.UDPConn,
	rxOffload bool,
	bufs [][]byte,
	sizes []int,
	eps []Endpoint,
) (n int, err error) {
	msgs := s.getMessages()
	for i := range bufs {
		(*msgs)[i].Buffers[0] = bufs[i]
		(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
	}
	defer s.putMessages(msgs)
	var numMsgs int
	if runtime.GOOS == "linux" || runtime.GOOS == "android" {
		if rxOffload {
			readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams)
			numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0)
			if err != nil {
				return 0, err
			}
			numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize)
			if err != nil {
				return 0, err
			}
		} else {
			numMsgs, err = br.ReadBatch(*msgs, 0)
			if err != nil {
				return 0, err
			}
		}
	} else {
		msg := &(*msgs)[0]
		msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
		if err != nil {
			return 0, err
		}
		numMsgs = 1
	}
	for i := 0; i < numMsgs; i++ {
		msg := &(*msgs)[i]
		sizes[i] = msg.N
		if sizes[i] == 0 {
			continue
		}
		addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
		ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
		getSrcFromControl(msg.OOB[:msg.NN], ep)
		eps[i] = ep
	}
	return numMsgs, nil
}

func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
	return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
		return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
	}
}

func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
	return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
		return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
	}
}
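// Illustrative sketch (not part of the original file): driving a ReceiveFunc
// returned by Open. All three slices must hold at least BatchSize entries; a
// single call may fill up to that many packets, and n reports how many
// entries of sizes and eps are valid.
func exampleReceiveLoop(s *StdNetBind, recv ReceiveFunc) error {
	batch := s.BatchSize()
	bufs := make([][]byte, batch)
	for i := range bufs {
		bufs[i] = make([]byte, 65535)
	}
	sizes := make([]int, batch)
	eps := make([]Endpoint, batch)
	for {
		n, err := recv(bufs, sizes, eps)
		if err != nil {
			return err
		}
		for i := 0; i < n; i++ {
			packet := bufs[i][:sizes[i]] // received from eps[i]
			_ = packet
		}
	}
}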
// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
// rename the IdealBatchSize constant to BatchSize.
func (s *StdNetBind) BatchSize() int {
	if runtime.GOOS == "linux" || runtime.GOOS == "android" {
		return IdealBatchSize
	}
	return 1
}

func (s *StdNetBind) Close() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	var err1, err2 error
	if s.ipv4 != nil {
		err1 = s.ipv4.Close()
		s.ipv4 = nil
		s.ipv4PC = nil
	}
	if s.ipv6 != nil {
		err2 = s.ipv6.Close()
		s.ipv6 = nil
		s.ipv6PC = nil
	}
	s.blackhole4 = false
	s.blackhole6 = false
	s.ipv4TxOffload = false
	s.ipv4RxOffload = false
	s.ipv6TxOffload = false
	s.ipv6RxOffload = false
	if err1 != nil {
		return err1
	}
	return err2
}

type ErrUDPGSODisabled struct {
	onLaddr  string
	RetryErr error
}

func (e ErrUDPGSODisabled) Error() string {
	return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.onLaddr)
}

func (e ErrUDPGSODisabled) Unwrap() error {
	return e.RetryErr
}

func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
	s.mu.Lock()
	blackhole := s.blackhole4
	conn := s.ipv4
	offload := s.ipv4TxOffload
	br := batchWriter(s.ipv4PC)
	is6 := false
	if endpoint.DstIP().Is6() {
		blackhole = s.blackhole6
		conn = s.ipv6
		br = s.ipv6PC
		is6 = true
		offload = s.ipv6TxOffload
	}
	s.mu.Unlock()

	if blackhole {
		return nil
	}
	if conn == nil {
		return syscall.EAFNOSUPPORT
	}

	msgs := s.getMessages()
	defer s.putMessages(msgs)
	ua := s.udpAddrPool.Get().(*net.UDPAddr)
	defer s.udpAddrPool.Put(ua)
	if is6 {
		as16 := endpoint.DstIP().As16()
		copy(ua.IP, as16[:])
		ua.IP = ua.IP[:16]
	} else {
		as4 := endpoint.DstIP().As4()
		copy(ua.IP, as4[:])
		ua.IP = ua.IP[:4]
	}
	ua.Port = int(endpoint.(*StdNetEndpoint).Port())
	var (
		retried bool
		err     error
	)
retry:
	if offload {
		n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize)
		err = s.send(conn, br, (*msgs)[:n])
		if err != nil && offload && errShouldDisableUDPGSO(err) {
			offload = false
			s.mu.Lock()
			if is6 {
				s.ipv6TxOffload = false
			} else {
				s.ipv4TxOffload = false
			}
			s.mu.Unlock()
			retried = true
			goto retry
		}
	} else {
		for i := range bufs {
			(*msgs)[i].Addr = ua
			(*msgs)[i].Buffers[0] = bufs[i]
			setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint))
		}
		err = s.send(conn, br, (*msgs)[:len(bufs)])
	}
	if retried {
		return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err}
	}
	return err
}

func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error {
	var (
		n     int
		err   error
		start int
	)
	if runtime.GOOS == "linux" || runtime.GOOS == "android" {
		for {
			n, err = pc.WriteBatch(msgs[start:], 0)
			if err != nil || n == len(msgs[start:]) {
				break
			}
			start += n
		}
	} else {
		for _, msg := range msgs {
			_, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr))
			if err != nil {
				break
			}
		}
	}
	return err
}
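// Illustrative sketch (not part of the original file): handling the
// ErrUDPGSODisabled notice from Send. By the time Send returns this error it
// has already disabled TX offload on the bind and retried without GSO, so a
// caller can surface the notice and treat the wrapped retry result as the
// real outcome (nil if the retry succeeded).
func exampleSendWithGSOFallback(s *StdNetBind, bufs [][]byte, ep Endpoint) error {
	err := s.Send(bufs, ep)
	var gsoErr ErrUDPGSODisabled
	if errors.As(err, &gsoErr) {
		fmt.Println(gsoErr.Error())
		return gsoErr.RetryErr
	}
	return err
}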
const (
	// Exceeding these values results in EMSGSIZE. They account for layer3 and
	// layer4 headers. IPv6 does not need to account for itself as the payload
	// length field is self excluding.
	maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8
	maxIPv6PayloadLen = 1<<16 - 1 - 8

	// This is a hard limit imposed by the kernel.
	udpSegmentMaxDatagrams = 64
)

type setGSOFunc func(control *[]byte, gsoSize uint16)

func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int {
	var (
		base     = -1 // index of msg we are currently coalescing into
		gsoSize  int  // segmentation size of msgs[base]
		dgramCnt int  // number of dgrams coalesced into msgs[base]
		endBatch bool // tracking flag to start a new batch on next iteration of bufs
	)
	maxPayloadLen := maxIPv4PayloadLen
	if ep.DstIP().Is6() {
		maxPayloadLen = maxIPv6PayloadLen
	}
	for i, buf := range bufs {
		if i > 0 {
			msgLen := len(buf)
			baseLenBefore := len(msgs[base].Buffers[0])
			freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
			if msgLen+baseLenBefore <= maxPayloadLen &&
				msgLen <= gsoSize &&
				msgLen <= freeBaseCap &&
				dgramCnt < udpSegmentMaxDatagrams &&
				!endBatch {
				msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...)
				if i == len(bufs)-1 {
					setGSO(&msgs[base].OOB, uint16(gsoSize))
				}
				dgramCnt++
				if msgLen < gsoSize {
					// A smaller than gsoSize packet on the tail is legal, but
					// it must end the batch.
					endBatch = true
				}
				continue
			}
		}
		if dgramCnt > 1 {
			setGSO(&msgs[base].OOB, uint16(gsoSize))
		}
		// Reset prior to incrementing base since we are preparing to start a
		// new potential batch.
		endBatch = false
		base++
		gsoSize = len(buf)
		setSrcControl(&msgs[base].OOB, ep)
		msgs[base].Buffers[0] = buf
		msgs[base].Addr = addr
		dgramCnt = 1
	}
	return base + 1
}

type getGSOFunc func(control []byte) (int, error)

func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) {
	for i := firstMsgAt; i < len(msgs); i++ {
		msg := &msgs[i]
		if msg.N == 0 {
			return n, err
		}
		var (
			gsoSize    int
			start      int
			end        = msg.N
			numToSplit = 1
		)
		gsoSize, err = getGSO(msg.OOB[:msg.NN])
		if err != nil {
			return n, err
		}
		if gsoSize > 0 {
			numToSplit = (msg.N + gsoSize - 1) / gsoSize
			end = gsoSize
		}
		for j := 0; j < numToSplit; j++ {
			if n > i {
				return n, errors.New("splitting coalesced packet resulted in overflow")
			}
			copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
			msgs[n].N = copied
			msgs[n].Addr = msg.Addr
			start = end
			end += gsoSize
			if end > msg.N {
				end = msg.N
			}
			n++
		}
		if i != n-1 {
			// It is legal for bytes to move within msg.Buffers[0] as a result
			// of splitting, so we only zero the source msg len when it is not
			// the destination of the last split operation above.
			msg.N = 0
		}
	}
	return n, nil
}
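// Illustrative sketch (not part of the original file): a worked example of
// coalesceMessages. The payload limits above work out to 65535-20-8 = 65507
// bytes for IPv4 and 65535-8 = 65527 bytes for IPv6. Three equal 1200-byte
// datagrams to one endpoint coalesce into a single 3600-byte message with a
// GSO segment size of 1200; the kernel later splits it back into three
// packets on the wire. The no-op setGSO below stands in for the platform
// control-message writer (setGSOSize on Linux).
func exampleCoalesce(addr *net.UDPAddr, ep *StdNetEndpoint) int {
	bufs := make([][]byte, 3)
	for i := range bufs {
		// Spare capacity in the first buffer is what allows later packets
		// to be appended to it.
		bufs[i] = make([]byte, 1200, 3*1200)
	}
	msgs := make([]ipv6.Message, len(bufs))
	for i := range msgs {
		msgs[i].Buffers = make(net.Buffers, 1)
		msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize)
	}
	noopGSO := func(control *[]byte, gsoSize uint16) {}
	// Expect 1: all three datagrams were coalesced into msgs[0].
	return coalesceMessages(addr, ep, bufs, msgs, noopGSO)
}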