github.com/koomox/wireguard-go@v0.0.0-20230722134753-17a50b2f22a3/conn/bind_std.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package conn 7 8 import ( 9 "context" 10 "errors" 11 "net" 12 "net/netip" 13 "runtime" 14 "strconv" 15 "sync" 16 "syscall" 17 18 "golang.org/x/net/ipv4" 19 "golang.org/x/net/ipv6" 20 ) 21 22 var ( 23 _ Bind = (*StdNetBind)(nil) 24 ) 25 26 // StdNetBind implements Bind for all platforms. While Windows has its own Bind 27 // (see bind_windows.go), it may fall back to StdNetBind. 28 // TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable 29 // methods for sending and receiving multiple datagrams per-syscall. See the 30 // proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564. 31 type StdNetBind struct { 32 mu sync.Mutex // protects all fields except as specified 33 ipv4 *net.UDPConn 34 ipv6 *net.UDPConn 35 ipv4PC *ipv4.PacketConn // will be nil on non-Linux 36 ipv6PC *ipv6.PacketConn // will be nil on non-Linux 37 38 // these three fields are not guarded by mu 39 udpAddrPool sync.Pool 40 ipv4MsgsPool sync.Pool 41 ipv6MsgsPool sync.Pool 42 43 blackhole4 bool 44 blackhole6 bool 45 } 46 47 func NewStdNetBind() Bind { 48 return &StdNetBind{ 49 udpAddrPool: sync.Pool{ 50 New: func() any { 51 return &net.UDPAddr{ 52 IP: make([]byte, 16), 53 } 54 }, 55 }, 56 57 ipv4MsgsPool: sync.Pool{ 58 New: func() any { 59 msgs := make([]ipv4.Message, IdealBatchSize) 60 for i := range msgs { 61 msgs[i].Buffers = make(net.Buffers, 1) 62 msgs[i].OOB = make([]byte, srcControlSize) 63 } 64 return &msgs 65 }, 66 }, 67 68 ipv6MsgsPool: sync.Pool{ 69 New: func() any { 70 msgs := make([]ipv6.Message, IdealBatchSize) 71 for i := range msgs { 72 msgs[i].Buffers = make(net.Buffers, 1) 73 msgs[i].OOB = make([]byte, srcControlSize) 74 } 75 return &msgs 76 }, 77 }, 78 } 79 } 80 81 type StdNetEndpoint struct { 82 // AddrPort is the endpoint destination. 83 netip.AddrPort 84 // src is the current sticky source address and interface index, if 85 // supported. Typically this is a PKTINFO structure from/for control 86 // messages, see unix.PKTINFO for an example. 87 src []byte 88 } 89 90 var ( 91 _ Bind = (*StdNetBind)(nil) 92 _ Endpoint = &StdNetEndpoint{} 93 ) 94 95 func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) { 96 e, err := netip.ParseAddrPort(s) 97 if err != nil { 98 return nil, err 99 } 100 return &StdNetEndpoint{ 101 AddrPort: e, 102 }, nil 103 } 104 105 func (e *StdNetEndpoint) ClearSrc() { 106 if e.src != nil { 107 // Truncate src, no need to reallocate. 108 e.src = e.src[:0] 109 } 110 } 111 112 func (e *StdNetEndpoint) DstIP() netip.Addr { 113 return e.AddrPort.Addr() 114 } 115 116 // See sticky_default,linux, etc for implementations of SrcIP and SrcIfidx. 117 118 func (e *StdNetEndpoint) DstToBytes() []byte { 119 b, _ := e.AddrPort.MarshalBinary() 120 return b 121 } 122 123 func (e *StdNetEndpoint) DstToString() string { 124 return e.AddrPort.String() 125 } 126 127 func listenNet(network string, port int) (*net.UDPConn, int, error) { 128 conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port)) 129 if err != nil { 130 return nil, 0, err 131 } 132 133 // Retrieve port. 134 laddr := conn.LocalAddr() 135 uaddr, err := net.ResolveUDPAddr( 136 laddr.Network(), 137 laddr.String(), 138 ) 139 if err != nil { 140 return nil, 0, err 141 } 142 return conn.(*net.UDPConn), uaddr.Port, nil 143 } 144 145 func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) { 146 s.mu.Lock() 147 defer s.mu.Unlock() 148 149 var err error 150 var tries int 151 152 if s.ipv4 != nil || s.ipv6 != nil { 153 return nil, 0, ErrBindAlreadyOpen 154 } 155 156 // Attempt to open ipv4 and ipv6 listeners on the same port. 157 // If uport is 0, we can retry on failure. 158 again: 159 port := int(uport) 160 var v4conn, v6conn *net.UDPConn 161 var v4pc *ipv4.PacketConn 162 var v6pc *ipv6.PacketConn 163 164 v4conn, port, err = listenNet("udp4", port) 165 if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { 166 return nil, 0, err 167 } 168 169 // Listen on the same port as we're using for ipv4. 170 v6conn, port, err = listenNet("udp6", port) 171 if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 { 172 v4conn.Close() 173 tries++ 174 goto again 175 } 176 if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { 177 v4conn.Close() 178 return nil, 0, err 179 } 180 var fns []ReceiveFunc 181 if v4conn != nil { 182 if runtime.GOOS == "linux" { 183 v4pc = ipv4.NewPacketConn(v4conn) 184 s.ipv4PC = v4pc 185 } 186 fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn)) 187 s.ipv4 = v4conn 188 } 189 if v6conn != nil { 190 if runtime.GOOS == "linux" { 191 v6pc = ipv6.NewPacketConn(v6conn) 192 s.ipv6PC = v6pc 193 } 194 fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn)) 195 s.ipv6 = v6conn 196 } 197 if len(fns) == 0 { 198 return nil, 0, syscall.EAFNOSUPPORT 199 } 200 201 return fns, uint16(port), nil 202 } 203 204 func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn) ReceiveFunc { 205 return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { 206 msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message) 207 defer s.ipv4MsgsPool.Put(msgs) 208 for i := range bufs { 209 (*msgs)[i].Buffers[0] = bufs[i] 210 } 211 var numMsgs int 212 if runtime.GOOS == "linux" { 213 numMsgs, err = pc.ReadBatch(*msgs, 0) 214 if err != nil { 215 return 0, err 216 } 217 } else { 218 msg := &(*msgs)[0] 219 msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) 220 if err != nil { 221 return 0, err 222 } 223 numMsgs = 1 224 } 225 for i := 0; i < numMsgs; i++ { 226 msg := &(*msgs)[i] 227 sizes[i] = msg.N 228 addrPort := msg.Addr.(*net.UDPAddr).AddrPort() 229 ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation 230 getSrcFromControl(msg.OOB[:msg.NN], ep) 231 eps[i] = ep 232 } 233 return numMsgs, nil 234 } 235 } 236 237 func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn) ReceiveFunc { 238 return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { 239 msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message) 240 defer s.ipv6MsgsPool.Put(msgs) 241 for i := range bufs { 242 (*msgs)[i].Buffers[0] = bufs[i] 243 } 244 var numMsgs int 245 if runtime.GOOS == "linux" { 246 numMsgs, err = pc.ReadBatch(*msgs, 0) 247 if err != nil { 248 return 0, err 249 } 250 } else { 251 msg := &(*msgs)[0] 252 msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) 253 if err != nil { 254 return 0, err 255 } 256 numMsgs = 1 257 } 258 for i := 0; i < numMsgs; i++ { 259 msg := &(*msgs)[i] 260 sizes[i] = msg.N 261 addrPort := msg.Addr.(*net.UDPAddr).AddrPort() 262 ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation 263 getSrcFromControl(msg.OOB[:msg.NN], ep) 264 eps[i] = ep 265 } 266 return numMsgs, nil 267 } 268 } 269 270 // TODO: When all Binds handle IdealBatchSize, remove this dynamic function and 271 // rename the IdealBatchSize constant to BatchSize. 272 func (s *StdNetBind) BatchSize() int { 273 if runtime.GOOS == "linux" { 274 return IdealBatchSize 275 } 276 return 1 277 } 278 279 func (s *StdNetBind) Close() error { 280 s.mu.Lock() 281 defer s.mu.Unlock() 282 283 var err1, err2 error 284 if s.ipv4 != nil { 285 err1 = s.ipv4.Close() 286 s.ipv4 = nil 287 s.ipv4PC = nil 288 } 289 if s.ipv6 != nil { 290 err2 = s.ipv6.Close() 291 s.ipv6 = nil 292 s.ipv6PC = nil 293 } 294 s.blackhole4 = false 295 s.blackhole6 = false 296 if err1 != nil { 297 return err1 298 } 299 return err2 300 } 301 302 func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error { 303 s.mu.Lock() 304 blackhole := s.blackhole4 305 conn := s.ipv4 306 var ( 307 pc4 *ipv4.PacketConn 308 pc6 *ipv6.PacketConn 309 ) 310 is6 := false 311 if endpoint.DstIP().Is6() { 312 blackhole = s.blackhole6 313 conn = s.ipv6 314 pc6 = s.ipv6PC 315 is6 = true 316 } else { 317 pc4 = s.ipv4PC 318 } 319 s.mu.Unlock() 320 321 if blackhole { 322 return nil 323 } 324 if conn == nil { 325 return syscall.EAFNOSUPPORT 326 } 327 if is6 { 328 return s.send6(conn, pc6, endpoint, bufs) 329 } else { 330 return s.send4(conn, pc4, endpoint, bufs) 331 } 332 } 333 334 func (s *StdNetBind) send4(conn *net.UDPConn, pc *ipv4.PacketConn, ep Endpoint, bufs [][]byte) error { 335 ua := s.udpAddrPool.Get().(*net.UDPAddr) 336 as4 := ep.DstIP().As4() 337 copy(ua.IP, as4[:]) 338 ua.IP = ua.IP[:4] 339 ua.Port = int(ep.(*StdNetEndpoint).Port()) 340 msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message) 341 for i, buf := range bufs { 342 (*msgs)[i].Buffers[0] = buf 343 (*msgs)[i].Addr = ua 344 setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint)) 345 } 346 var ( 347 n int 348 err error 349 start int 350 ) 351 if runtime.GOOS == "linux" { 352 for { 353 n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0) 354 if err != nil || n == len((*msgs)[start:len(bufs)]) { 355 break 356 } 357 start += n 358 } 359 } else { 360 for i, buf := range bufs { 361 _, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua) 362 if err != nil { 363 break 364 } 365 } 366 } 367 s.udpAddrPool.Put(ua) 368 s.ipv4MsgsPool.Put(msgs) 369 return err 370 } 371 372 func (s *StdNetBind) send6(conn *net.UDPConn, pc *ipv6.PacketConn, ep Endpoint, bufs [][]byte) error { 373 ua := s.udpAddrPool.Get().(*net.UDPAddr) 374 as16 := ep.DstIP().As16() 375 copy(ua.IP, as16[:]) 376 ua.IP = ua.IP[:16] 377 ua.Port = int(ep.(*StdNetEndpoint).Port()) 378 msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message) 379 for i, buf := range bufs { 380 (*msgs)[i].Buffers[0] = buf 381 (*msgs)[i].Addr = ua 382 setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint)) 383 } 384 var ( 385 n int 386 err error 387 start int 388 ) 389 if runtime.GOOS == "linux" { 390 for { 391 n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0) 392 if err != nil || n == len((*msgs)[start:len(bufs)]) { 393 break 394 } 395 start += n 396 } 397 } else { 398 for i, buf := range bufs { 399 _, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua) 400 if err != nil { 401 break 402 } 403 } 404 } 405 s.udpAddrPool.Put(ua) 406 s.ipv6MsgsPool.Put(msgs) 407 return err 408 }