// github.com/yaling888/clash@v1.53.0/transport/wireguard/bind_std.go

//go:build !nogvisor

package wireguard

import (
	"context"
	"errors"
	"fmt"
	"net"
	"net/netip"
	"runtime"
	"strconv"
	"sync"
	"syscall"
	_ "unsafe" // required for the //go:linkname directives below

	"golang.org/x/net/ipv4"
	"golang.org/x/net/ipv6"
	wg "golang.zx2c4.com/wireguard/conn"
)

// The bodyless declarations below are bound by //go:linkname to unexported
// functions of the same name in golang.zx2c4.com/wireguard/conn, so this bind
// can reuse wireguard-go's control-message and UDP offload helpers.
// NOTE(review): the compiler does not check these signatures against the
// linked targets; they must be kept identical to the wireguard-go version
// this module depends on, because a mismatch is not caught at build time.

// getSrcFromControl is called by receiveIP with the received control (OOB)
// bytes and a freshly built endpoint; presumably it records the packet's
// local source address in ep — confirm against wireguard-go.
//
//go:linkname getSrcFromControl golang.zx2c4.com/wireguard/conn.getSrcFromControl
func getSrcFromControl(control []byte, ep *wg.StdNetEndpoint)

// setSrcControl is called by Send (non-offload path) to fill a message's
// control buffer for ep; presumably it encodes ep's sticky source address.
//
//go:linkname setSrcControl golang.zx2c4.com/wireguard/conn.setSrcControl
func setSrcControl(control *[]byte, ep *wg.StdNetEndpoint)

// getGSOSize is handed to splitCoalescedMessages; presumably it extracts the
// segment size from received control data.
//
//go:linkname getGSOSize golang.zx2c4.com/wireguard/conn.getGSOSize
func getGSOSize(control []byte) (int, error)

// setGSOSize is handed to coalesceMessages; presumably it writes a segment
// size into an outgoing message's control data.
//
//go:linkname setGSOSize golang.zx2c4.com/wireguard/conn.setGSOSize
func setGSOSize(control *[]byte, gsoSize uint16)

// supportsUDPOffload reports whether conn can use transmit and receive UDP
// offload; Open stores the result in the bind's Tx/RxOffload fields.
//
//go:linkname supportsUDPOffload golang.zx2c4.com/wireguard/conn.supportsUDPOffload
func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool)

// errShouldDisableUDPGSO reports whether a send error means GSO should be
// turned off; Send disables tx offload and retries when it returns true.
//
//go:linkname errShouldDisableUDPGSO golang.zx2c4.com/wireguard/conn.errShouldDisableUDPGSO
func errShouldDisableUDPGSO(err error) bool

// coalesceMessages packs bufs destined for addr into msgs (using setGSO to
// annotate coalesced messages) and returns how many entries of msgs to send.
//
//go:linkname coalesceMessages golang.zx2c4.com/wireguard/conn.coalesceMessages
func coalesceMessages(addr *net.UDPAddr, ep *wg.StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int

// splitCoalescedMessages expands the messages read at msgs[firstMsgAt:] into
// msgs starting at index 0 (receiveIP consumes msgs[:n]) and returns n.
//
//go:linkname splitCoalescedMessages golang.zx2c4.com/wireguard/conn.splitCoalescedMessages
func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error)

// udpSegmentMaxDatagrams is the most datagrams one coalesced UDP segment may
// carry; receiveIP uses it to size the window it reads into when rx offload
// is enabled.
const udpSegmentMaxDatagrams = 64 // This is a hard limit imposed by the kernel.
47 48 type setGSOFunc func(control *[]byte, gsoSize uint16) 49 50 type getGSOFunc func(control []byte) (int, error) 51 52 var _ wg.Bind = (*StdNetBind)(nil) 53 54 type StdNetBind struct { 55 mu sync.Mutex // protects all fields except as specified 56 ipv4 *net.UDPConn 57 ipv6 *net.UDPConn 58 ipv4PC *ipv4.PacketConn // will be nil on non-Linux 59 ipv6PC *ipv6.PacketConn // will be nil on non-Linux 60 ipv4TxOffload bool 61 ipv4RxOffload bool 62 ipv6TxOffload bool 63 ipv6RxOffload bool 64 65 // these two fields are not guarded by mu 66 udpAddrPool sync.Pool 67 msgsPool sync.Pool 68 69 blackhole4 bool 70 blackhole6 bool 71 72 controlFns []func(network, address string, c syscall.RawConn) error 73 interfaceName string 74 reserved []byte 75 } 76 77 func (s *StdNetBind) setReserved(b []byte) { 78 if len(b) < 4 || s.reserved == nil { 79 return 80 } 81 b[1] = s.reserved[0] 82 b[2] = s.reserved[1] 83 b[3] = s.reserved[2] 84 } 85 86 func (s *StdNetBind) resetReserved(b []byte) { 87 if len(b) < 4 { 88 return 89 } 90 b[1] = 0x00 91 b[2] = 0x00 92 b[3] = 0x00 93 } 94 95 func (s *StdNetBind) listenConfig() *net.ListenConfig { 96 return &net.ListenConfig{ 97 Control: func(network, address string, c syscall.RawConn) error { 98 for _, fn := range s.controlFns { 99 if err := fn(network, address, c); err != nil { 100 return err 101 } 102 } 103 return nil 104 }, 105 } 106 } 107 108 func (s *StdNetBind) listenNet(network string, port int) (*net.UDPConn, int, error) { 109 listenIP, err := getListenIP(network, s.interfaceName) 110 if err != nil { 111 return nil, 0, err 112 } 113 114 conn, err := s.listenConfig().ListenPacket(context.Background(), network, listenIP+":"+strconv.Itoa(port)) 115 if err != nil { 116 return nil, 0, err 117 } 118 119 // Retrieve port. 
120 laddr := conn.LocalAddr() 121 uaddr, err := net.ResolveUDPAddr( 122 laddr.Network(), 123 laddr.String(), 124 ) 125 if err != nil { 126 return nil, 0, err 127 } 128 return conn.(*net.UDPConn), uaddr.Port, nil 129 } 130 131 func (s *StdNetBind) SetMark(mark uint32) error { 132 return nil 133 } 134 135 func (*StdNetBind) ParseEndpoint(s string) (wg.Endpoint, error) { 136 e, err := netip.ParseAddrPort(s) 137 if err != nil { 138 return nil, err 139 } 140 return &wg.StdNetEndpoint{ 141 AddrPort: e, 142 }, nil 143 } 144 145 func (s *StdNetBind) UpdateControlFns(controlFns []func(network, address string, c syscall.RawConn) error) { 146 s.controlFns = controlFns 147 } 148 149 func NewStdNetBind( 150 controlFns []func(network, address string, c syscall.RawConn) error, 151 interfaceName string, 152 reserved []byte, 153 ) wg.Bind { 154 return &StdNetBind{ 155 udpAddrPool: sync.Pool{ 156 New: func() any { 157 return &net.UDPAddr{ 158 IP: make([]byte, 16), 159 } 160 }, 161 }, 162 163 msgsPool: sync.Pool{ 164 New: func() any { 165 // ipv6.Message and ipv4.Message are interchangeable as they are 166 // both aliases for x/net/internal/socket.Message. 167 msgs := make([]ipv6.Message, wg.IdealBatchSize) 168 for i := range msgs { 169 msgs[i].Buffers = make(net.Buffers, 1) 170 msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize) 171 } 172 return &msgs 173 }, 174 }, 175 176 controlFns: controlFns, 177 interfaceName: interfaceName, 178 reserved: reserved, 179 } 180 } 181 182 func (s *StdNetBind) Open(uport uint16) ([]wg.ReceiveFunc, uint16, error) { 183 s.mu.Lock() 184 defer s.mu.Unlock() 185 186 var err error 187 var tries int 188 189 if s.ipv4 != nil || s.ipv6 != nil { 190 return nil, 0, wg.ErrBindAlreadyOpen 191 } 192 193 // Attempt to open ipv4 and ipv6 listeners on the same port. 194 // If uport is 0, we can retry on failure. 
195 again: 196 port := int(uport) 197 var v4conn, v6conn *net.UDPConn 198 var v4pc *ipv4.PacketConn 199 var v6pc *ipv6.PacketConn 200 201 v4conn, port, err = s.listenNet("udp4", port) 202 if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { 203 return nil, 0, err 204 } 205 206 // Listen on the same port as we're using for ipv4. 207 v6conn, port, err = s.listenNet("udp6", port) 208 if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 { 209 v4conn.Close() 210 tries++ 211 goto again 212 } 213 if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { 214 v4conn.Close() 215 return nil, 0, err 216 } 217 var fns []wg.ReceiveFunc 218 if v4conn != nil { 219 s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn) 220 if runtime.GOOS == "linux" || runtime.GOOS == "android" { 221 v4pc = ipv4.NewPacketConn(v4conn) 222 s.ipv4PC = v4pc 223 } 224 fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload)) 225 s.ipv4 = v4conn 226 } 227 if v6conn != nil { 228 s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn) 229 if runtime.GOOS == "linux" || runtime.GOOS == "android" { 230 v6pc = ipv6.NewPacketConn(v6conn) 231 s.ipv6PC = v6pc 232 } 233 fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload)) 234 s.ipv6 = v6conn 235 } 236 if len(fns) == 0 { 237 return nil, 0, syscall.EAFNOSUPPORT 238 } 239 240 return fns, uint16(port), nil 241 } 242 243 func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) { 244 for i := range *msgs { 245 (*msgs)[i].OOB = (*msgs)[i].OOB[:0] 246 (*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB} 247 } 248 s.msgsPool.Put(msgs) 249 } 250 251 func (s *StdNetBind) getMessages() *[]ipv6.Message { 252 return s.msgsPool.Get().(*[]ipv6.Message) 253 } 254 255 var _ ipv6.Message = ipv4.Message{} // If compilation fails here these are no longer the same underlying type. 
256 257 type batchReader interface { 258 ReadBatch([]ipv6.Message, int) (int, error) 259 } 260 261 type batchWriter interface { 262 WriteBatch([]ipv6.Message, int) (int, error) 263 } 264 265 func (s *StdNetBind) receiveIP( 266 br batchReader, 267 conn *net.UDPConn, 268 rxOffload bool, 269 bufs [][]byte, 270 sizes []int, 271 eps []wg.Endpoint, 272 ) (numMsgs int, err error) { 273 msgs := s.getMessages() 274 for i := range bufs { 275 (*msgs)[i].Buffers[0] = bufs[i] 276 (*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)] 277 } 278 defer s.putMessages(msgs) 279 if runtime.GOOS == "linux" || runtime.GOOS == "android" { 280 if rxOffload { 281 readAt := len(*msgs) - (wg.IdealBatchSize / udpSegmentMaxDatagrams) 282 numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0) 283 if err != nil { 284 return 0, err 285 } 286 numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize) 287 if err != nil { 288 return 0, err 289 } 290 } else { 291 numMsgs, err = br.ReadBatch(*msgs, 0) 292 if err != nil { 293 return 0, err 294 } 295 } 296 } else { 297 msg := &(*msgs)[0] 298 msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) 299 if err != nil { 300 return 0, err 301 } 302 numMsgs = 1 303 } 304 for i := 0; i < numMsgs; i++ { 305 msg := &(*msgs)[i] 306 sizes[i] = msg.N 307 if sizes[i] == 0 { 308 continue 309 } 310 addrPort := msg.Addr.(*net.UDPAddr).AddrPort() 311 ep := &wg.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation 312 getSrcFromControl(msg.OOB[:msg.NN], ep) 313 eps[i] = ep 314 s.resetReserved(msg.Buffers[0]) 315 } 316 return numMsgs, nil 317 } 318 319 func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) wg.ReceiveFunc { 320 return func(bufs [][]byte, sizes []int, eps []wg.Endpoint) (n int, err error) { 321 return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps) 322 } 323 } 324 325 func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) wg.ReceiveFunc { 326 return 
func(bufs [][]byte, sizes []int, eps []wg.Endpoint) (n int, err error) { 327 return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps) 328 } 329 } 330 331 // TODO: When all Binds handle IdealBatchSize, remove this dynamic function and 332 // rename the IdealBatchSize constant to BatchSize. 333 func (s *StdNetBind) BatchSize() int { 334 if runtime.GOOS == "linux" || runtime.GOOS == "android" { 335 return wg.IdealBatchSize 336 } 337 return 1 338 } 339 340 func (s *StdNetBind) Close() error { 341 s.mu.Lock() 342 defer s.mu.Unlock() 343 344 var err1, err2 error 345 if s.ipv4 != nil { 346 err1 = s.ipv4.Close() 347 s.ipv4 = nil 348 s.ipv4PC = nil 349 } 350 if s.ipv6 != nil { 351 err2 = s.ipv6.Close() 352 s.ipv6 = nil 353 s.ipv6PC = nil 354 } 355 s.blackhole4 = false 356 s.blackhole6 = false 357 s.ipv4TxOffload = false 358 s.ipv4RxOffload = false 359 s.ipv6TxOffload = false 360 s.ipv6RxOffload = false 361 if err1 != nil { 362 return err1 363 } 364 return err2 365 } 366 367 func (s *StdNetBind) Send(bufs [][]byte, endpoint wg.Endpoint) error { 368 s.mu.Lock() 369 blackhole := s.blackhole4 370 conn := s.ipv4 371 offload := s.ipv4TxOffload 372 br := batchWriter(s.ipv4PC) 373 is6 := false 374 if endpoint.DstIP().Is6() { 375 blackhole = s.blackhole6 376 conn = s.ipv6 377 br = s.ipv6PC 378 is6 = true 379 offload = s.ipv6TxOffload 380 } 381 s.mu.Unlock() 382 383 if blackhole { 384 return nil 385 } 386 if conn == nil { 387 return syscall.EAFNOSUPPORT 388 } 389 390 for i := range bufs { 391 s.setReserved(bufs[i]) 392 } 393 394 msgs := s.getMessages() 395 defer s.putMessages(msgs) 396 ua := s.udpAddrPool.Get().(*net.UDPAddr) 397 defer s.udpAddrPool.Put(ua) 398 if is6 { 399 as16 := endpoint.DstIP().As16() 400 copy(ua.IP, as16[:]) 401 ua.IP = ua.IP[:16] 402 } else { 403 as4 := endpoint.DstIP().As4() 404 copy(ua.IP, as4[:]) 405 ua.IP = ua.IP[:4] 406 } 407 ua.Port = int(endpoint.(*wg.StdNetEndpoint).Port()) 408 var ( 409 retried bool 410 err error 411 ) 412 retry: 413 if offload { 414 
n := coalesceMessages(ua, endpoint.(*wg.StdNetEndpoint), bufs, *msgs, setGSOSize) 415 err = s.send(conn, br, (*msgs)[:n]) 416 if err != nil && offload && errShouldDisableUDPGSO(err) { 417 offload = false 418 s.mu.Lock() 419 if is6 { 420 s.ipv6TxOffload = false 421 } else { 422 s.ipv4TxOffload = false 423 } 424 s.mu.Unlock() 425 retried = true 426 goto retry 427 } 428 } else { 429 for i := range bufs { 430 (*msgs)[i].Addr = ua 431 (*msgs)[i].Buffers[0] = bufs[i] 432 setSrcControl(&(*msgs)[i].OOB, endpoint.(*wg.StdNetEndpoint)) 433 } 434 err = s.send(conn, br, (*msgs)[:len(bufs)]) 435 } 436 if retried { 437 return wg.ErrUDPGSODisabled{RetryErr: fmt.Errorf("disabled UDP GSO on %s, %w", conn.LocalAddr().String(), err)} 438 } 439 return err 440 } 441 442 func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error { 443 var ( 444 n int 445 err error 446 start int 447 ) 448 if runtime.GOOS == "linux" || runtime.GOOS == "android" { 449 for { 450 n, err = pc.WriteBatch(msgs[start:], 0) 451 if err != nil || n == len(msgs[start:]) { 452 break 453 } 454 start += n 455 } 456 } else { 457 for _, msg := range msgs { 458 _, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr)) 459 if err != nil { 460 break 461 } 462 } 463 } 464 return err 465 }