github.com/metacubex/sing-tun@v0.2.7-0.20240512075008-89e7c6208eec/stack_system.go (about) 1 package tun 2 3 import ( 4 "context" 5 "net" 6 "net/netip" 7 "syscall" 8 "time" 9 10 "github.com/metacubex/sing-tun/internal/clashtcpip" 11 "github.com/sagernet/sing/common" 12 "github.com/sagernet/sing/common/buf" 13 "github.com/sagernet/sing/common/control" 14 E "github.com/sagernet/sing/common/exceptions" 15 "github.com/sagernet/sing/common/logger" 16 M "github.com/sagernet/sing/common/metadata" 17 N "github.com/sagernet/sing/common/network" 18 "github.com/sagernet/sing/common/udpnat" 19 ) 20 21 type System struct { 22 ctx context.Context 23 tun Tun 24 tunName string 25 mtu int 26 handler Handler 27 logger logger.Logger 28 inet4Prefixes []netip.Prefix 29 inet6Prefixes []netip.Prefix 30 inet4ServerAddress netip.Addr 31 inet4Address netip.Addr 32 inet6ServerAddress netip.Addr 33 inet6Address netip.Addr 34 broadcastAddr netip.Addr 35 udpTimeout int64 36 tcpListener net.Listener 37 tcpListener6 net.Listener 38 tcpPort uint16 39 tcpPort6 uint16 40 tcpNat *TCPNat 41 udpNat *udpnat.Service[netip.AddrPort] 42 bindInterface bool 43 interfaceFinder control.InterfaceFinder 44 enforceBind bool 45 frontHeadroom int 46 txChecksumOffload bool 47 } 48 49 type Session struct { 50 SourceAddress netip.Addr 51 DestinationAddress netip.Addr 52 SourcePort uint16 53 DestinationPort uint16 54 } 55 56 func NewSystem(options StackOptions) (Stack, error) { 57 stack := &System{ 58 ctx: options.Context, 59 tun: options.Tun, 60 tunName: options.TunOptions.Name, 61 mtu: int(options.TunOptions.MTU), 62 udpTimeout: options.UDPTimeout, 63 handler: options.Handler, 64 logger: options.Logger, 65 inet4Prefixes: options.TunOptions.Inet4Address, 66 inet6Prefixes: options.TunOptions.Inet6Address, 67 broadcastAddr: BroadcastAddr(options.TunOptions.Inet4Address), 68 bindInterface: options.ForwarderBindInterface, 69 interfaceFinder: options.InterfaceFinder, 70 enforceBind: options.EnforceBindInterface, 71 } 72 if len(options.TunOptions.Inet4Address) > 0 { 73 if options.TunOptions.Inet4Address[0].Bits() == 32 { 74 return nil, E.New("need one more IPv4 address in first prefix for system stack") 75 } 76 stack.inet4ServerAddress = options.TunOptions.Inet4Address[0].Addr() 77 stack.inet4Address = stack.inet4ServerAddress.Next() 78 } 79 if len(options.TunOptions.Inet6Address) > 0 { 80 if options.TunOptions.Inet6Address[0].Bits() == 128 { 81 return nil, E.New("need one more IPv6 address in first prefix for system stack") 82 } 83 stack.inet6ServerAddress = options.TunOptions.Inet6Address[0].Addr() 84 stack.inet6Address = stack.inet6ServerAddress.Next() 85 } 86 if !stack.inet4Address.IsValid() && !stack.inet6Address.IsValid() { 87 return nil, E.New("missing interface address") 88 } 89 return stack, nil 90 } 91 92 func (s *System) Close() error { 93 return common.Close( 94 s.tcpListener, 95 s.tcpListener6, 96 ) 97 } 98 99 func (s *System) Start() error { 100 err := s.start() 101 if err != nil { 102 return err 103 } 104 go s.tunLoop() 105 return nil 106 } 107 108 func (s *System) start() error { 109 err := fixWindowsFirewall() 110 if err != nil { 111 return E.Cause(err, "fix windows firewall for system stack") 112 } 113 var listener net.ListenConfig 114 if s.bindInterface || s.enforceBind { 115 listener.Control = control.Append(listener.Control, func(network, address string, conn syscall.RawConn) error { 116 bindErr := control.BindToInterface0(s.interfaceFinder, conn, network, address, s.tunName, -1, true) 117 if bindErr != nil { 118 s.logger.Warn("bind forwarder to interface: ", bindErr) 119 } 120 if s.enforceBind { 121 return bindErr 122 } 123 return nil 124 }) 125 } 126 if s.inet4Address.IsValid() { 127 address := net.JoinHostPort(s.inet4ServerAddress.String(), "0") 128 if s.enforceBind { 129 address = "0.0.0.0:0" 130 } 131 tcpListener, err := listener.Listen(s.ctx, "tcp4", address) 132 if err != nil { 133 return err 134 } 135 s.tcpListener = tcpListener 136 s.tcpPort = M.SocksaddrFromNet(tcpListener.Addr()).Port 137 go s.acceptLoop(tcpListener) 138 } 139 if s.inet6Address.IsValid() { 140 address := net.JoinHostPort(s.inet6ServerAddress.String(), "0") 141 if s.enforceBind { 142 address = "[:]:0" 143 } 144 tcpListener, err := listener.Listen(s.ctx, "tcp6", address) 145 if err != nil { 146 return err 147 } 148 s.tcpListener6 = tcpListener 149 s.tcpPort6 = M.SocksaddrFromNet(tcpListener.Addr()).Port 150 go s.acceptLoop(tcpListener) 151 } 152 s.tcpNat = NewNat(s.ctx, time.Second*time.Duration(s.udpTimeout)) 153 s.udpNat = udpnat.New[netip.AddrPort](s.udpTimeout, s.handler) 154 return nil 155 } 156 157 func (s *System) tunLoop() { 158 if winTun, isWinTun := s.tun.(WinTun); isWinTun { 159 s.wintunLoop(winTun) 160 return 161 } 162 if linuxTUN, isLinuxTUN := s.tun.(LinuxTUN); isLinuxTUN { 163 s.frontHeadroom = linuxTUN.FrontHeadroom() 164 s.txChecksumOffload = linuxTUN.TXChecksumOffload() 165 batchSize := linuxTUN.BatchSize() 166 if batchSize > 1 { 167 s.batchLoop(linuxTUN, batchSize) 168 return 169 } 170 } 171 packetBuffer := make([]byte, s.mtu+PacketOffset) 172 for { 173 n, err := s.tun.Read(packetBuffer) 174 if err != nil { 175 if E.IsClosed(err) { 176 return 177 } 178 s.logger.Error(E.Cause(err, "read packet")) 179 } 180 if n < clashtcpip.IPv4PacketMinLength { 181 continue 182 } 183 rawPacket := packetBuffer[:n] 184 packet := packetBuffer[PacketOffset:n] 185 if s.processPacket(packet) { 186 _, err = s.tun.Write(rawPacket) 187 if err != nil { 188 s.logger.Trace(E.Cause(err, "write packet")) 189 } 190 } 191 } 192 } 193 194 func (s *System) wintunLoop(winTun WinTun) { 195 for { 196 packet, release, err := winTun.ReadPacket() 197 if err != nil { 198 return 199 } 200 if len(packet) < clashtcpip.IPv4PacketMinLength { 201 release() 202 continue 203 } 204 if s.processPacket(packet) { 205 _, err = winTun.Write(packet) 206 if err != nil { 207 s.logger.Trace(E.Cause(err, "write packet")) 208 } 209 } 210 release() 211 } 212 } 213 214 func (s *System) batchLoop(linuxTUN LinuxTUN, batchSize int) { 215 packetBuffers := make([][]byte, batchSize) 216 writeBuffers := make([][]byte, batchSize) 217 packetSizes := make([]int, batchSize) 218 for i := range packetBuffers { 219 packetBuffers[i] = make([]byte, s.mtu+s.frontHeadroom) 220 } 221 for { 222 n, err := linuxTUN.BatchRead(packetBuffers, s.frontHeadroom, packetSizes) 223 if err != nil { 224 if E.IsClosed(err) { 225 return 226 } 227 s.logger.Error(E.Cause(err, "batch read packet")) 228 } 229 if n == 0 { 230 continue 231 } 232 for i := 0; i < n; i++ { 233 packetSize := packetSizes[i] 234 if packetSize < clashtcpip.IPv4PacketMinLength { 235 continue 236 } 237 packetBuffer := packetBuffers[i] 238 packet := packetBuffer[s.frontHeadroom : s.frontHeadroom+packetSize] 239 if s.processPacket(packet) { 240 writeBuffers = append(writeBuffers, packetBuffer[:s.frontHeadroom+packetSize]) 241 } 242 } 243 if len(writeBuffers) > 0 { 244 err = linuxTUN.BatchWrite(writeBuffers, s.frontHeadroom) 245 if err != nil { 246 s.logger.Trace(E.Cause(err, "batch write packet")) 247 } 248 writeBuffers = writeBuffers[:0] 249 } 250 } 251 } 252 253 func (s *System) processPacket(packet []byte) bool { 254 var ( 255 writeBack bool 256 err error 257 ) 258 switch ipVersion := packet[0] >> 4; ipVersion { 259 case 4: 260 writeBack, err = s.processIPv4(packet) 261 case 6: 262 writeBack, err = s.processIPv6(packet) 263 default: 264 err = E.New("ip: unknown version: ", ipVersion) 265 } 266 if err != nil { 267 s.logger.Trace(err) 268 return false 269 } 270 return writeBack 271 } 272 273 func (s *System) acceptLoop(listener net.Listener) { 274 for { 275 conn, err := listener.Accept() 276 if err != nil { 277 return 278 } 279 connPort := M.SocksaddrFromNet(conn.RemoteAddr()).Port 280 session := s.tcpNat.LookupBack(connPort) 281 if session == nil { 282 s.logger.Trace(E.New("unknown session with port ", connPort)) 283 continue 284 } 285 destination := M.SocksaddrFromNetIP(session.Destination) 286 if destination.Addr.Is4() { 287 for _, prefix := range s.inet4Prefixes { 288 if prefix.Contains(destination.Addr) { 289 destination.Addr = netip.AddrFrom4([4]byte{127, 0, 0, 1}) 290 break 291 } 292 } 293 } else { 294 for _, prefix := range s.inet6Prefixes { 295 if prefix.Contains(destination.Addr) { 296 destination.Addr = netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}) 297 break 298 } 299 } 300 } 301 go func() { 302 _ = s.handler.NewConnection(s.ctx, conn, M.Metadata{ 303 Source: M.SocksaddrFromNetIP(session.Source), 304 Destination: destination, 305 }) 306 if tcpConn, isTCPConn := conn.(*net.TCPConn); isTCPConn { 307 _ = tcpConn.SetLinger(0) 308 } 309 _ = conn.Close() 310 }() 311 } 312 } 313 314 func (s *System) processIPv4(packet clashtcpip.IPv4Packet) (writeBack bool, err error) { 315 writeBack = true 316 destination := packet.DestinationIP() 317 if destination == s.broadcastAddr || !destination.IsGlobalUnicast() { 318 return 319 } 320 switch packet.Protocol() { 321 case clashtcpip.TCP: 322 err = s.processIPv4TCP(packet, packet.Payload()) 323 case clashtcpip.UDP: 324 writeBack = false 325 err = s.processIPv4UDP(packet, packet.Payload()) 326 case clashtcpip.ICMP: 327 err = s.processIPv4ICMP(packet, packet.Payload()) 328 } 329 return 330 } 331 332 func (s *System) processIPv6(packet clashtcpip.IPv6Packet) (writeBack bool, err error) { 333 writeBack = true 334 if !packet.DestinationIP().IsGlobalUnicast() { 335 return 336 } 337 switch packet.Protocol() { 338 case clashtcpip.TCP: 339 err = s.processIPv6TCP(packet, packet.Payload()) 340 case clashtcpip.UDP: 341 writeBack = false 342 err = s.processIPv6UDP(packet, packet.Payload()) 343 case clashtcpip.ICMPv6: 344 err = s.processIPv6ICMP(packet, packet.Payload()) 345 } 346 return 347 } 348 349 func (s *System) processIPv4TCP(packet clashtcpip.IPv4Packet, header clashtcpip.TCPPacket) error { 350 source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort()) 351 destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort()) 352 if !destination.Addr().IsGlobalUnicast() { 353 return nil 354 } else if source.Addr() == s.inet4ServerAddress && source.Port() == s.tcpPort { 355 session := s.tcpNat.LookupBack(destination.Port()) 356 if session == nil { 357 return E.New("ipv4: tcp: session not found: ", destination.Port()) 358 } 359 packet.SetSourceIP(session.Destination.Addr()) 360 header.SetSourcePort(session.Destination.Port()) 361 packet.SetDestinationIP(session.Source.Addr()) 362 header.SetDestinationPort(session.Source.Port()) 363 } else { 364 natPort := s.tcpNat.Lookup(source, destination) 365 packet.SetSourceIP(s.inet4Address) 366 header.SetSourcePort(natPort) 367 packet.SetDestinationIP(s.inet4ServerAddress) 368 header.SetDestinationPort(s.tcpPort) 369 } 370 if !s.txChecksumOffload { 371 header.ResetChecksum(packet.PseudoSum()) 372 packet.ResetChecksum() 373 } else { 374 header.OffloadChecksum() 375 packet.ResetChecksum() 376 } 377 return nil 378 } 379 380 func (s *System) processIPv6TCP(packet clashtcpip.IPv6Packet, header clashtcpip.TCPPacket) error { 381 source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort()) 382 destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort()) 383 if !destination.Addr().IsGlobalUnicast() { 384 return nil 385 } else if source.Addr() == s.inet6ServerAddress && source.Port() == s.tcpPort6 { 386 session := s.tcpNat.LookupBack(destination.Port()) 387 if session == nil { 388 return E.New("ipv6: tcp: session not found: ", destination.Port()) 389 } 390 packet.SetSourceIP(session.Destination.Addr()) 391 header.SetSourcePort(session.Destination.Port()) 392 packet.SetDestinationIP(session.Source.Addr()) 393 header.SetDestinationPort(session.Source.Port()) 394 } else { 395 natPort := s.tcpNat.Lookup(source, destination) 396 packet.SetSourceIP(s.inet6Address) 397 header.SetSourcePort(natPort) 398 packet.SetDestinationIP(s.inet6ServerAddress) 399 header.SetDestinationPort(s.tcpPort6) 400 } 401 if !s.txChecksumOffload { 402 header.ResetChecksum(packet.PseudoSum()) 403 } else { 404 header.OffloadChecksum() 405 } 406 return nil 407 } 408 409 func (s *System) processIPv4UDP(packet clashtcpip.IPv4Packet, header clashtcpip.UDPPacket) error { 410 if packet.Flags()&clashtcpip.FlagMoreFragment != 0 { 411 return E.New("ipv4: fragment dropped") 412 } 413 if packet.FragmentOffset() != 0 { 414 return E.New("ipv4: udp: fragment dropped") 415 } 416 if !header.Valid() { 417 return E.New("ipv4: udp: invalid packet") 418 } 419 source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort()) 420 destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort()) 421 if !destination.Addr().IsGlobalUnicast() { 422 return nil 423 } 424 data := buf.As(header.Payload()) 425 if data.Len() == 0 { 426 return nil 427 } 428 metadata := M.Metadata{ 429 Source: M.SocksaddrFromNetIP(source), 430 Destination: M.SocksaddrFromNetIP(destination), 431 } 432 s.udpNat.NewPacket(s.ctx, source, data.ToOwned(), metadata, func(natConn N.PacketConn) N.PacketWriter { 433 headerLen := packet.HeaderLen() + clashtcpip.UDPHeaderSize 434 headerCopy := make([]byte, headerLen) 435 copy(headerCopy, packet[:headerLen]) 436 return &systemUDPPacketWriter4{ 437 s.tun, 438 s.frontHeadroom + PacketOffset, 439 headerCopy, 440 source, 441 s.txChecksumOffload, 442 } 443 }) 444 return nil 445 } 446 447 func (s *System) processIPv6UDP(packet clashtcpip.IPv6Packet, header clashtcpip.UDPPacket) error { 448 if !header.Valid() { 449 return E.New("ipv6: udp: invalid packet") 450 } 451 source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort()) 452 destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort()) 453 if !destination.Addr().IsGlobalUnicast() { 454 return nil 455 } 456 data := buf.As(header.Payload()) 457 if data.Len() == 0 { 458 return nil 459 } 460 metadata := M.Metadata{ 461 Source: M.SocksaddrFromNetIP(source), 462 Destination: M.SocksaddrFromNetIP(destination), 463 } 464 s.udpNat.NewPacket(s.ctx, source, data.ToOwned(), metadata, func(natConn N.PacketConn) N.PacketWriter { 465 headerLen := len(packet) - int(header.Length()) + clashtcpip.UDPHeaderSize 466 headerCopy := make([]byte, headerLen) 467 copy(headerCopy, packet[:headerLen]) 468 return &systemUDPPacketWriter6{ 469 s.tun, 470 s.frontHeadroom + PacketOffset, 471 headerCopy, 472 source, 473 s.txChecksumOffload, 474 } 475 }) 476 return nil 477 } 478 479 func (s *System) processIPv4ICMP(packet clashtcpip.IPv4Packet, header clashtcpip.ICMPPacket) error { 480 if header.Type() != clashtcpip.ICMPTypePingRequest || header.Code() != 0 { 481 return nil 482 } 483 header.SetType(clashtcpip.ICMPTypePingResponse) 484 sourceAddress := packet.SourceIP() 485 packet.SetSourceIP(packet.DestinationIP()) 486 packet.SetDestinationIP(sourceAddress) 487 header.ResetChecksum() 488 packet.ResetChecksum() 489 return nil 490 } 491 492 func (s *System) processIPv6ICMP(packet clashtcpip.IPv6Packet, header clashtcpip.ICMPv6Packet) error { 493 if header.Type() != clashtcpip.ICMPv6EchoRequest || header.Code() != 0 { 494 return nil 495 } 496 header.SetType(clashtcpip.ICMPv6EchoReply) 497 sourceAddress := packet.SourceIP() 498 packet.SetSourceIP(packet.DestinationIP()) 499 packet.SetDestinationIP(sourceAddress) 500 header.ResetChecksum(packet.PseudoSum()) 501 packet.ResetChecksum() 502 return nil 503 } 504 505 type systemUDPPacketWriter4 struct { 506 tun Tun 507 frontHeadroom int 508 header []byte 509 source netip.AddrPort 510 txChecksumOffload bool 511 } 512 513 func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { 514 newPacket := buf.NewSize(w.frontHeadroom + len(w.header) + buffer.Len()) 515 defer newPacket.Release() 516 newPacket.Resize(w.frontHeadroom, 0) 517 newPacket.Write(w.header) 518 newPacket.Write(buffer.Bytes()) 519 ipHdr := clashtcpip.IPv4Packet(newPacket.Bytes()) 520 ipHdr.SetTotalLength(uint16(newPacket.Len())) 521 ipHdr.SetDestinationIP(ipHdr.SourceIP()) 522 ipHdr.SetSourceIP(destination.Addr) 523 udpHdr := clashtcpip.UDPPacket(ipHdr.Payload()) 524 udpHdr.SetDestinationPort(udpHdr.SourcePort()) 525 udpHdr.SetSourcePort(destination.Port) 526 udpHdr.SetLength(uint16(buffer.Len() + clashtcpip.UDPHeaderSize)) 527 if !w.txChecksumOffload { 528 udpHdr.ResetChecksum(ipHdr.PseudoSum()) 529 ipHdr.ResetChecksum() 530 } else { 531 udpHdr.OffloadChecksum() 532 ipHdr.ResetChecksum() 533 } 534 if PacketOffset > 0 { 535 newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET 536 } else { 537 newPacket.Advance(-w.frontHeadroom) 538 } 539 return common.Error(w.tun.Write(newPacket.Bytes())) 540 } 541 542 type systemUDPPacketWriter6 struct { 543 tun Tun 544 frontHeadroom int 545 header []byte 546 source netip.AddrPort 547 txChecksumOffload bool 548 } 549 550 func (w *systemUDPPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { 551 newPacket := buf.NewSize(w.frontHeadroom + len(w.header) + buffer.Len()) 552 defer newPacket.Release() 553 newPacket.Resize(w.frontHeadroom, 0) 554 newPacket.Write(w.header) 555 newPacket.Write(buffer.Bytes()) 556 ipHdr := clashtcpip.IPv6Packet(newPacket.Bytes()) 557 udpLen := uint16(clashtcpip.UDPHeaderSize + buffer.Len()) 558 ipHdr.SetPayloadLength(udpLen) 559 ipHdr.SetDestinationIP(ipHdr.SourceIP()) 560 ipHdr.SetSourceIP(destination.Addr) 561 udpHdr := clashtcpip.UDPPacket(ipHdr.Payload()) 562 udpHdr.SetDestinationPort(udpHdr.SourcePort()) 563 udpHdr.SetSourcePort(destination.Port) 564 udpHdr.SetLength(udpLen) 565 if !w.txChecksumOffload { 566 udpHdr.ResetChecksum(ipHdr.PseudoSum()) 567 } else { 568 udpHdr.OffloadChecksum() 569 } 570 if PacketOffset > 0 { 571 newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET6 572 } else { 573 newPacket.Advance(-w.frontHeadroom) 574 } 575 return common.Error(w.tun.Write(newPacket.Bytes())) 576 }