github.com/amnezia-vpn/amneziawg-go@v0.2.8/conn/bind_std.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package conn 7 8 import ( 9 "context" 10 "errors" 11 "fmt" 12 "net" 13 "net/netip" 14 "runtime" 15 "strconv" 16 "sync" 17 "syscall" 18 19 "golang.org/x/net/ipv4" 20 "golang.org/x/net/ipv6" 21 ) 22 23 var ( 24 _ Bind = (*StdNetBind)(nil) 25 ) 26 27 // StdNetBind implements Bind for all platforms. While Windows has its own Bind 28 // (see bind_windows.go), it may fall back to StdNetBind. 29 // TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable 30 // methods for sending and receiving multiple datagrams per-syscall. See the 31 // proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564. 32 type StdNetBind struct { 33 mu sync.Mutex // protects all fields except as specified 34 ipv4 *net.UDPConn 35 ipv6 *net.UDPConn 36 ipv4PC *ipv4.PacketConn // will be nil on non-Linux 37 ipv6PC *ipv6.PacketConn // will be nil on non-Linux 38 ipv4TxOffload bool 39 ipv4RxOffload bool 40 ipv6TxOffload bool 41 ipv6RxOffload bool 42 43 // these two fields are not guarded by mu 44 udpAddrPool sync.Pool 45 msgsPool sync.Pool 46 47 blackhole4 bool 48 blackhole6 bool 49 } 50 51 func NewStdNetBind() Bind { 52 return &StdNetBind{ 53 udpAddrPool: sync.Pool{ 54 New: func() any { 55 return &net.UDPAddr{ 56 IP: make([]byte, 16), 57 } 58 }, 59 }, 60 61 msgsPool: sync.Pool{ 62 New: func() any { 63 // ipv6.Message and ipv4.Message are interchangeable as they are 64 // both aliases for x/net/internal/socket.Message. 65 msgs := make([]ipv6.Message, IdealBatchSize) 66 for i := range msgs { 67 msgs[i].Buffers = make(net.Buffers, 1) 68 msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize) 69 } 70 return &msgs 71 }, 72 }, 73 } 74 } 75 76 type StdNetEndpoint struct { 77 // AddrPort is the endpoint destination. 78 netip.AddrPort 79 // src is the current sticky source address and interface index, if 80 // supported. Typically this is a PKTINFO structure from/for control 81 // messages, see unix.PKTINFO for an example. 82 src []byte 83 } 84 85 var ( 86 _ Bind = (*StdNetBind)(nil) 87 _ Endpoint = &StdNetEndpoint{} 88 ) 89 90 func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) { 91 e, err := netip.ParseAddrPort(s) 92 if err != nil { 93 return nil, err 94 } 95 return &StdNetEndpoint{ 96 AddrPort: e, 97 }, nil 98 } 99 100 func (e *StdNetEndpoint) ClearSrc() { 101 if e.src != nil { 102 // Truncate src, no need to reallocate. 103 e.src = e.src[:0] 104 } 105 } 106 107 func (e *StdNetEndpoint) DstIP() netip.Addr { 108 return e.AddrPort.Addr() 109 } 110 111 // See control_default,linux, etc for implementations of SrcIP and SrcIfidx. 112 113 func (e *StdNetEndpoint) DstToBytes() []byte { 114 b, _ := e.AddrPort.MarshalBinary() 115 return b 116 } 117 118 func (e *StdNetEndpoint) DstToString() string { 119 return e.AddrPort.String() 120 } 121 122 func listenNet(network string, port int) (*net.UDPConn, int, error) { 123 conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port)) 124 if err != nil { 125 return nil, 0, err 126 } 127 128 // Retrieve port. 129 laddr := conn.LocalAddr() 130 uaddr, err := net.ResolveUDPAddr( 131 laddr.Network(), 132 laddr.String(), 133 ) 134 if err != nil { 135 return nil, 0, err 136 } 137 return conn.(*net.UDPConn), uaddr.Port, nil 138 } 139 140 func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) { 141 s.mu.Lock() 142 defer s.mu.Unlock() 143 144 var err error 145 var tries int 146 147 if s.ipv4 != nil || s.ipv6 != nil { 148 return nil, 0, ErrBindAlreadyOpen 149 } 150 151 // Attempt to open ipv4 and ipv6 listeners on the same port. 152 // If uport is 0, we can retry on failure. 153 again: 154 port := int(uport) 155 var v4conn, v6conn *net.UDPConn 156 var v4pc *ipv4.PacketConn 157 var v6pc *ipv6.PacketConn 158 159 v4conn, port, err = listenNet("udp4", port) 160 if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { 161 return nil, 0, err 162 } 163 164 // Listen on the same port as we're using for ipv4. 165 v6conn, port, err = listenNet("udp6", port) 166 if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 { 167 v4conn.Close() 168 tries++ 169 goto again 170 } 171 if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { 172 v4conn.Close() 173 return nil, 0, err 174 } 175 var fns []ReceiveFunc 176 if v4conn != nil { 177 s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn) 178 if runtime.GOOS == "linux" || runtime.GOOS == "android" { 179 v4pc = ipv4.NewPacketConn(v4conn) 180 s.ipv4PC = v4pc 181 } 182 fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload)) 183 s.ipv4 = v4conn 184 } 185 if v6conn != nil { 186 s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn) 187 if runtime.GOOS == "linux" || runtime.GOOS == "android" { 188 v6pc = ipv6.NewPacketConn(v6conn) 189 s.ipv6PC = v6pc 190 } 191 fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload)) 192 s.ipv6 = v6conn 193 } 194 if len(fns) == 0 { 195 return nil, 0, syscall.EAFNOSUPPORT 196 } 197 198 return fns, uint16(port), nil 199 } 200 201 func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) { 202 for i := range *msgs { 203 (*msgs)[i].OOB = (*msgs)[i].OOB[:0] 204 (*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB} 205 } 206 s.msgsPool.Put(msgs) 207 } 208 209 func (s *StdNetBind) getMessages() *[]ipv6.Message { 210 return s.msgsPool.Get().(*[]ipv6.Message) 211 } 212 213 var ( 214 // If compilation fails here these are no longer the same underlying type. 215 _ ipv6.Message = ipv4.Message{} 216 ) 217 218 type batchReader interface { 219 ReadBatch([]ipv6.Message, int) (int, error) 220 } 221 222 type batchWriter interface { 223 WriteBatch([]ipv6.Message, int) (int, error) 224 } 225 226 func (s *StdNetBind) receiveIP( 227 br batchReader, 228 conn *net.UDPConn, 229 rxOffload bool, 230 bufs [][]byte, 231 sizes []int, 232 eps []Endpoint, 233 ) (n int, err error) { 234 msgs := s.getMessages() 235 for i := range bufs { 236 (*msgs)[i].Buffers[0] = bufs[i] 237 (*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)] 238 } 239 defer s.putMessages(msgs) 240 var numMsgs int 241 if runtime.GOOS == "linux" || runtime.GOOS == "android" { 242 if rxOffload { 243 readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams) 244 numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0) 245 if err != nil { 246 return 0, err 247 } 248 numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize) 249 if err != nil { 250 return 0, err 251 } 252 } else { 253 numMsgs, err = br.ReadBatch(*msgs, 0) 254 if err != nil { 255 return 0, err 256 } 257 } 258 } else { 259 msg := &(*msgs)[0] 260 msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) 261 if err != nil { 262 return 0, err 263 } 264 numMsgs = 1 265 } 266 for i := 0; i < numMsgs; i++ { 267 msg := &(*msgs)[i] 268 sizes[i] = msg.N 269 if sizes[i] == 0 { 270 continue 271 } 272 addrPort := msg.Addr.(*net.UDPAddr).AddrPort() 273 ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation 274 getSrcFromControl(msg.OOB[:msg.NN], ep) 275 eps[i] = ep 276 } 277 return numMsgs, nil 278 } 279 280 func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc { 281 return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { 282 return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps) 283 } 284 } 285 286 func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc { 287 return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { 288 return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps) 289 } 290 } 291 292 // TODO: When all Binds handle IdealBatchSize, remove this dynamic function and 293 // rename the IdealBatchSize constant to BatchSize. 294 func (s *StdNetBind) BatchSize() int { 295 if runtime.GOOS == "linux" || runtime.GOOS == "android" { 296 return IdealBatchSize 297 } 298 return 1 299 } 300 301 func (s *StdNetBind) GetOffloadInfo() string { 302 return fmt.Sprintf("ipv4TxOffload: %v, ipv4RxOffload: %v\nipv6TxOffload: %v, ipv6RxOffload: %v", 303 s.ipv4TxOffload, s.ipv4RxOffload, s.ipv6TxOffload, s.ipv6RxOffload) 304 } 305 306 func (s *StdNetBind) Close() error { 307 s.mu.Lock() 308 defer s.mu.Unlock() 309 310 var err1, err2 error 311 if s.ipv4 != nil { 312 err1 = s.ipv4.Close() 313 s.ipv4 = nil 314 s.ipv4PC = nil 315 } 316 if s.ipv6 != nil { 317 err2 = s.ipv6.Close() 318 s.ipv6 = nil 319 s.ipv6PC = nil 320 } 321 s.blackhole4 = false 322 s.blackhole6 = false 323 s.ipv4TxOffload = false 324 s.ipv4RxOffload = false 325 s.ipv6TxOffload = false 326 s.ipv6RxOffload = false 327 if err1 != nil { 328 return err1 329 } 330 return err2 331 } 332 333 type ErrUDPGSODisabled struct { 334 onLaddr string 335 RetryErr error 336 } 337 338 func (e ErrUDPGSODisabled) Error() string { 339 return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload or peer MTU with protocol headers is greater than path MTU", e.onLaddr) 340 } 341 342 func (e ErrUDPGSODisabled) Unwrap() error { 343 return e.RetryErr 344 } 345 346 func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error { 347 s.mu.Lock() 348 blackhole := s.blackhole4 349 conn := s.ipv4 350 offload := s.ipv4TxOffload 351 br := batchWriter(s.ipv4PC) 352 is6 := false 353 if endpoint.DstIP().Is6() { 354 blackhole = s.blackhole6 355 conn = s.ipv6 356 br = s.ipv6PC 357 is6 = true 358 offload = s.ipv6TxOffload 359 } 360 s.mu.Unlock() 361 362 if blackhole { 363 return nil 364 } 365 if conn == nil { 366 return syscall.EAFNOSUPPORT 367 } 368 369 msgs := s.getMessages() 370 defer s.putMessages(msgs) 371 ua := s.udpAddrPool.Get().(*net.UDPAddr) 372 defer s.udpAddrPool.Put(ua) 373 if is6 { 374 as16 := endpoint.DstIP().As16() 375 copy(ua.IP, as16[:]) 376 ua.IP = ua.IP[:16] 377 } else { 378 as4 := endpoint.DstIP().As4() 379 copy(ua.IP, as4[:]) 380 ua.IP = ua.IP[:4] 381 } 382 ua.Port = int(endpoint.(*StdNetEndpoint).Port()) 383 var ( 384 retried bool 385 err error 386 ) 387 retry: 388 if offload { 389 n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize) 390 err = s.send(conn, br, (*msgs)[:n]) 391 if err != nil && offload && errShouldDisableUDPGSO(err) { 392 offload = false 393 s.mu.Lock() 394 if is6 { 395 s.ipv6TxOffload = false 396 } else { 397 s.ipv4TxOffload = false 398 } 399 s.mu.Unlock() 400 retried = true 401 goto retry 402 } 403 } else { 404 for i := range bufs { 405 (*msgs)[i].Addr = ua 406 (*msgs)[i].Buffers[0] = bufs[i] 407 setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint)) 408 } 409 err = s.send(conn, br, (*msgs)[:len(bufs)]) 410 } 411 if retried { 412 return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err} 413 } 414 return err 415 } 416 417 func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error { 418 var ( 419 n int 420 err error 421 start int 422 ) 423 if runtime.GOOS == "linux" || runtime.GOOS == "android" { 424 for { 425 n, err = pc.WriteBatch(msgs[start:], 0) 426 if err != nil || n == len(msgs[start:]) { 427 break 428 } 429 start += n 430 } 431 } else { 432 for _, msg := range msgs { 433 _, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr)) 434 if err != nil { 435 break 436 } 437 } 438 } 439 return err 440 } 441 442 const ( 443 // Exceeding these values results in EMSGSIZE. They account for layer3 and 444 // layer4 headers. IPv6 does not need to account for itself as the payload 445 // length field is self excluding. 446 maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8 447 maxIPv6PayloadLen = 1<<16 - 1 - 8 448 449 // This is a hard limit imposed by the kernel. 450 udpSegmentMaxDatagrams = 64 451 ) 452 453 type setGSOFunc func(control *[]byte, gsoSize uint16) 454 455 func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int { 456 var ( 457 base = -1 // index of msg we are currently coalescing into 458 gsoSize int // segmentation size of msgs[base] 459 dgramCnt int // number of dgrams coalesced into msgs[base] 460 endBatch bool // tracking flag to start a new batch on next iteration of bufs 461 ) 462 maxPayloadLen := maxIPv4PayloadLen 463 if ep.DstIP().Is6() { 464 maxPayloadLen = maxIPv6PayloadLen 465 } 466 for i, buf := range bufs { 467 if i > 0 { 468 msgLen := len(buf) 469 baseLenBefore := len(msgs[base].Buffers[0]) 470 freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore 471 if msgLen+baseLenBefore <= maxPayloadLen && 472 msgLen <= gsoSize && 473 msgLen <= freeBaseCap && 474 dgramCnt < udpSegmentMaxDatagrams && 475 !endBatch { 476 msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...) 477 if i == len(bufs)-1 { 478 setGSO(&msgs[base].OOB, uint16(gsoSize)) 479 } 480 dgramCnt++ 481 if msgLen < gsoSize { 482 // A smaller than gsoSize packet on the tail is legal, but 483 // it must end the batch. 484 endBatch = true 485 } 486 continue 487 } 488 } 489 if dgramCnt > 1 { 490 setGSO(&msgs[base].OOB, uint16(gsoSize)) 491 } 492 // Reset prior to incrementing base since we are preparing to start a 493 // new potential batch. 494 endBatch = false 495 base++ 496 gsoSize = len(buf) 497 setSrcControl(&msgs[base].OOB, ep) 498 msgs[base].Buffers[0] = buf 499 msgs[base].Addr = addr 500 dgramCnt = 1 501 } 502 return base + 1 503 } 504 505 type getGSOFunc func(control []byte) (int, error) 506 507 func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) { 508 for i := firstMsgAt; i < len(msgs); i++ { 509 msg := &msgs[i] 510 if msg.N == 0 { 511 return n, err 512 } 513 var ( 514 gsoSize int 515 start int 516 end = msg.N 517 numToSplit = 1 518 ) 519 gsoSize, err = getGSO(msg.OOB[:msg.NN]) 520 if err != nil { 521 return n, err 522 } 523 if gsoSize > 0 { 524 numToSplit = (msg.N + gsoSize - 1) / gsoSize 525 end = gsoSize 526 } 527 for j := 0; j < numToSplit; j++ { 528 if n > i { 529 return n, errors.New("splitting coalesced packet resulted in overflow") 530 } 531 copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end]) 532 msgs[n].N = copied 533 msgs[n].Addr = msg.Addr 534 start = end 535 end += gsoSize 536 if end > msg.N { 537 end = msg.N 538 } 539 n++ 540 } 541 if i != n-1 { 542 // It is legal for bytes to move within msg.Buffers[0] as a result 543 // of splitting, so we only zero the source msg len when it is not 544 // the destination of the last split operation above. 545 msg.N = 0 546 } 547 } 548 return n, nil 549 }