github.com/apernet/sing-tun@v0.2.6-0.20240323130332-b9f6511036ad/stack_system.go (about) 1 package tun 2 3 import ( 4 "context" 5 "net" 6 "net/netip" 7 "syscall" 8 "time" 9 10 "github.com/apernet/sing-tun/internal/clashtcpip" 11 "github.com/sagernet/sing/common" 12 "github.com/sagernet/sing/common/buf" 13 "github.com/sagernet/sing/common/control" 14 E "github.com/sagernet/sing/common/exceptions" 15 "github.com/sagernet/sing/common/logger" 16 M "github.com/sagernet/sing/common/metadata" 17 N "github.com/sagernet/sing/common/network" 18 "github.com/sagernet/sing/common/udpnat" 19 ) 20 21 type System struct { 22 ctx context.Context 23 tun Tun 24 tunName string 25 mtu int 26 handler Handler 27 logger logger.Logger 28 inet4Prefixes []netip.Prefix 29 inet6Prefixes []netip.Prefix 30 inet4ServerAddress netip.Addr 31 inet4Address netip.Addr 32 inet6ServerAddress netip.Addr 33 inet6Address netip.Addr 34 broadcastAddr netip.Addr 35 udpTimeout int64 36 tcpListener net.Listener 37 tcpListener6 net.Listener 38 tcpPort uint16 39 tcpPort6 uint16 40 tcpNat *TCPNat 41 udpNat *udpnat.Service[netip.AddrPort] 42 bindInterface bool 43 interfaceFinder control.InterfaceFinder 44 frontHeadroom int 45 txChecksumOffload bool 46 } 47 48 type Session struct { 49 SourceAddress netip.Addr 50 DestinationAddress netip.Addr 51 SourcePort uint16 52 DestinationPort uint16 53 } 54 55 func NewSystem(options StackOptions) (Stack, error) { 56 stack := &System{ 57 ctx: options.Context, 58 tun: options.Tun, 59 tunName: options.TunOptions.Name, 60 mtu: int(options.TunOptions.MTU), 61 udpTimeout: options.UDPTimeout, 62 handler: options.Handler, 63 logger: options.Logger, 64 inet4Prefixes: options.TunOptions.Inet4Address, 65 inet6Prefixes: options.TunOptions.Inet6Address, 66 broadcastAddr: BroadcastAddr(options.TunOptions.Inet4Address), 67 bindInterface: options.ForwarderBindInterface, 68 interfaceFinder: options.InterfaceFinder, 69 } 70 if len(options.TunOptions.Inet4Address) > 0 { 71 if options.TunOptions.Inet4Address[0].Bits() == 32 { 72 return nil, E.New("need one more IPv4 address in first prefix for system stack") 73 } 74 stack.inet4ServerAddress = options.TunOptions.Inet4Address[0].Addr() 75 stack.inet4Address = stack.inet4ServerAddress.Next() 76 } 77 if len(options.TunOptions.Inet6Address) > 0 { 78 if options.TunOptions.Inet6Address[0].Bits() == 128 { 79 return nil, E.New("need one more IPv6 address in first prefix for system stack") 80 } 81 stack.inet6ServerAddress = options.TunOptions.Inet6Address[0].Addr() 82 stack.inet6Address = stack.inet6ServerAddress.Next() 83 } 84 if !stack.inet4Address.IsValid() && !stack.inet6Address.IsValid() { 85 return nil, E.New("missing interface address") 86 } 87 return stack, nil 88 } 89 90 func (s *System) Close() error { 91 return common.Close( 92 s.tcpListener, 93 s.tcpListener6, 94 ) 95 } 96 97 func (s *System) Start() error { 98 err := s.start() 99 if err != nil { 100 return err 101 } 102 go s.tunLoop() 103 return nil 104 } 105 106 func (s *System) start() error { 107 err := fixWindowsFirewall() 108 if err != nil { 109 return E.Cause(err, "fix windows firewall for system stack") 110 } 111 var listener net.ListenConfig 112 if s.bindInterface { 113 listener.Control = control.Append(listener.Control, func(network, address string, conn syscall.RawConn) error { 114 bindErr := control.BindToInterface0(s.interfaceFinder, conn, network, address, s.tunName, -1, true) 115 if bindErr != nil { 116 s.logger.Warn("bind forwarder to interface: ", bindErr) 117 } 118 return nil 119 }) 120 } 121 if s.inet4Address.IsValid() { 122 tcpListener, err := listener.Listen(s.ctx, "tcp4", net.JoinHostPort(s.inet4ServerAddress.String(), "0")) 123 if err != nil { 124 return err 125 } 126 s.tcpListener = tcpListener 127 s.tcpPort = M.SocksaddrFromNet(tcpListener.Addr()).Port 128 go s.acceptLoop(tcpListener) 129 } 130 if s.inet6Address.IsValid() { 131 tcpListener, err := listener.Listen(s.ctx, "tcp6", net.JoinHostPort(s.inet6ServerAddress.String(), "0")) 132 if err != nil { 133 return err 134 } 135 s.tcpListener6 = tcpListener 136 s.tcpPort6 = M.SocksaddrFromNet(tcpListener.Addr()).Port 137 go s.acceptLoop(tcpListener) 138 } 139 s.tcpNat = NewNat(s.ctx, time.Second*time.Duration(s.udpTimeout)) 140 s.udpNat = udpnat.New[netip.AddrPort](s.udpTimeout, s.handler) 141 return nil 142 } 143 144 func (s *System) tunLoop() error { 145 if winTun, isWinTun := s.tun.(WinTun); isWinTun { 146 return s.wintunLoop(winTun) 147 } 148 if linuxTUN, isLinuxTUN := s.tun.(LinuxTUN); isLinuxTUN { 149 s.frontHeadroom = linuxTUN.FrontHeadroom() 150 s.txChecksumOffload = linuxTUN.TXChecksumOffload() 151 batchSize := linuxTUN.BatchSize() 152 if batchSize > 1 { 153 return s.batchLoop(linuxTUN, batchSize) 154 } 155 } 156 packetBuffer := make([]byte, s.mtu+PacketOffset) 157 for { 158 n, err := s.tun.Read(packetBuffer) 159 if err != nil { 160 if E.IsClosed(err) { 161 return err 162 } 163 s.logger.Error(E.Cause(err, "read packet")) 164 } 165 if n < clashtcpip.IPv4PacketMinLength { 166 continue 167 } 168 rawPacket := packetBuffer[:n] 169 packet := packetBuffer[PacketOffset:n] 170 if s.processPacket(packet) { 171 _, err = s.tun.Write(rawPacket) 172 if err != nil { 173 s.logger.Trace(E.Cause(err, "write packet")) 174 } 175 } 176 } 177 } 178 179 func (s *System) wintunLoop(winTun WinTun) error { 180 for { 181 packet, release, err := winTun.ReadPacket() 182 if err != nil { 183 return err 184 } 185 if len(packet) < clashtcpip.IPv4PacketMinLength { 186 release() 187 continue 188 } 189 if s.processPacket(packet) { 190 _, err = winTun.Write(packet) 191 if err != nil { 192 s.logger.Trace(E.Cause(err, "write packet")) 193 } 194 } 195 release() 196 } 197 } 198 199 func (s *System) batchLoop(linuxTUN LinuxTUN, batchSize int) error { 200 packetBuffers := make([][]byte, batchSize) 201 writeBuffers := make([][]byte, batchSize) 202 packetSizes := make([]int, batchSize) 203 for i := range packetBuffers { 204 packetBuffers[i] = make([]byte, s.mtu+s.frontHeadroom) 205 } 206 for { 207 n, err := linuxTUN.BatchRead(packetBuffers, s.frontHeadroom, packetSizes) 208 if err != nil { 209 if E.IsClosed(err) || E.IsMulti(err, errBadFd) || isErrNotPollable(err) { 210 return err 211 } 212 s.logger.Error(E.Cause(err, "batch read packet")) 213 } 214 if n == 0 { 215 continue 216 } 217 for i := 0; i < n; i++ { 218 packetSize := packetSizes[i] 219 if packetSize < clashtcpip.IPv4PacketMinLength { 220 continue 221 } 222 packetBuffer := packetBuffers[i] 223 packet := packetBuffer[s.frontHeadroom : s.frontHeadroom+packetSize] 224 if s.processPacket(packet) { 225 writeBuffers = append(writeBuffers, packetBuffer[:s.frontHeadroom+packetSize]) 226 } 227 } 228 if len(writeBuffers) > 0 { 229 err = linuxTUN.BatchWrite(writeBuffers, s.frontHeadroom) 230 if err != nil { 231 s.logger.Trace(E.Cause(err, "batch write packet")) 232 } 233 writeBuffers = writeBuffers[:0] 234 } 235 } 236 } 237 238 func (s *System) processPacket(packet []byte) bool { 239 var ( 240 writeBack bool 241 err error 242 ) 243 switch ipVersion := packet[0] >> 4; ipVersion { 244 case 4: 245 writeBack, err = s.processIPv4(packet) 246 case 6: 247 writeBack, err = s.processIPv6(packet) 248 default: 249 err = E.New("ip: unknown version: ", ipVersion) 250 } 251 if err != nil { 252 s.logger.Trace(err) 253 return false 254 } 255 return writeBack 256 } 257 258 func (s *System) acceptLoop(listener net.Listener) { 259 for { 260 conn, err := listener.Accept() 261 if err != nil { 262 return 263 } 264 connPort := M.SocksaddrFromNet(conn.RemoteAddr()).Port 265 session := s.tcpNat.LookupBack(connPort) 266 if session == nil { 267 s.logger.Trace(E.New("unknown session with port ", connPort)) 268 continue 269 } 270 destination := M.SocksaddrFromNetIP(session.Destination) 271 if destination.Addr.Is4() { 272 for _, prefix := range s.inet4Prefixes { 273 if prefix.Contains(destination.Addr) { 274 destination.Addr = netip.AddrFrom4([4]byte{127, 0, 0, 1}) 275 break 276 } 277 } 278 } else { 279 for _, prefix := range s.inet6Prefixes { 280 if prefix.Contains(destination.Addr) { 281 destination.Addr = netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}) 282 break 283 } 284 } 285 } 286 go func() { 287 _ = s.handler.NewConnection(s.ctx, conn, M.Metadata{ 288 Source: M.SocksaddrFromNetIP(session.Source), 289 Destination: destination, 290 }) 291 if tcpConn, isTCPConn := conn.(*net.TCPConn); isTCPConn { 292 _ = tcpConn.SetLinger(0) 293 } 294 _ = conn.Close() 295 }() 296 } 297 } 298 299 func (s *System) processIPv4(packet clashtcpip.IPv4Packet) (writeBack bool, err error) { 300 writeBack = true 301 destination := packet.DestinationIP() 302 if destination == s.broadcastAddr || !destination.IsGlobalUnicast() { 303 return 304 } 305 switch packet.Protocol() { 306 case clashtcpip.TCP: 307 err = s.processIPv4TCP(packet, packet.Payload()) 308 case clashtcpip.UDP: 309 writeBack = false 310 err = s.processIPv4UDP(packet, packet.Payload()) 311 case clashtcpip.ICMP: 312 err = s.processIPv4ICMP(packet, packet.Payload()) 313 } 314 return 315 } 316 317 func (s *System) processIPv6(packet clashtcpip.IPv6Packet) (writeBack bool, err error) { 318 writeBack = true 319 if !packet.DestinationIP().IsGlobalUnicast() { 320 return 321 } 322 switch packet.Protocol() { 323 case clashtcpip.TCP: 324 err = s.processIPv6TCP(packet, packet.Payload()) 325 case clashtcpip.UDP: 326 writeBack = false 327 err = s.processIPv6UDP(packet, packet.Payload()) 328 case clashtcpip.ICMPv6: 329 err = s.processIPv6ICMP(packet, packet.Payload()) 330 } 331 return 332 } 333 334 func (s *System) processIPv4TCP(packet clashtcpip.IPv4Packet, header clashtcpip.TCPPacket) error { 335 source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort()) 336 destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort()) 337 if !destination.Addr().IsGlobalUnicast() { 338 return nil 339 } else if source.Addr() == s.inet4ServerAddress && source.Port() == s.tcpPort { 340 session := s.tcpNat.LookupBack(destination.Port()) 341 if session == nil { 342 return E.New("ipv4: tcp: session not found: ", destination.Port()) 343 } 344 packet.SetSourceIP(session.Destination.Addr()) 345 header.SetSourcePort(session.Destination.Port()) 346 packet.SetDestinationIP(session.Source.Addr()) 347 header.SetDestinationPort(session.Source.Port()) 348 } else { 349 natPort := s.tcpNat.Lookup(source, destination) 350 packet.SetSourceIP(s.inet4Address) 351 header.SetSourcePort(natPort) 352 packet.SetDestinationIP(s.inet4ServerAddress) 353 header.SetDestinationPort(s.tcpPort) 354 } 355 if !s.txChecksumOffload { 356 header.ResetChecksum(packet.PseudoSum()) 357 packet.ResetChecksum() 358 } else { 359 header.OffloadChecksum() 360 packet.ResetChecksum() 361 } 362 return nil 363 } 364 365 func (s *System) processIPv6TCP(packet clashtcpip.IPv6Packet, header clashtcpip.TCPPacket) error { 366 source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort()) 367 destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort()) 368 if !destination.Addr().IsGlobalUnicast() { 369 return nil 370 } else if source.Addr() == s.inet6ServerAddress && source.Port() == s.tcpPort6 { 371 session := s.tcpNat.LookupBack(destination.Port()) 372 if session == nil { 373 return E.New("ipv6: tcp: session not found: ", destination.Port()) 374 } 375 packet.SetSourceIP(session.Destination.Addr()) 376 header.SetSourcePort(session.Destination.Port()) 377 packet.SetDestinationIP(session.Source.Addr()) 378 header.SetDestinationPort(session.Source.Port()) 379 } else { 380 natPort := s.tcpNat.Lookup(source, destination) 381 packet.SetSourceIP(s.inet6Address) 382 header.SetSourcePort(natPort) 383 packet.SetDestinationIP(s.inet6ServerAddress) 384 header.SetDestinationPort(s.tcpPort6) 385 } 386 if !s.txChecksumOffload { 387 header.ResetChecksum(packet.PseudoSum()) 388 } else { 389 header.OffloadChecksum() 390 } 391 return nil 392 } 393 394 func (s *System) processIPv4UDP(packet clashtcpip.IPv4Packet, header clashtcpip.UDPPacket) error { 395 if packet.Flags()&clashtcpip.FlagMoreFragment != 0 { 396 return E.New("ipv4: fragment dropped") 397 } 398 if packet.FragmentOffset() != 0 { 399 return E.New("ipv4: udp: fragment dropped") 400 } 401 if !header.Valid() { 402 return E.New("ipv4: udp: invalid packet") 403 } 404 source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort()) 405 destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort()) 406 if !destination.Addr().IsGlobalUnicast() { 407 return nil 408 } 409 data := buf.As(header.Payload()) 410 if data.Len() == 0 { 411 return nil 412 } 413 metadata := M.Metadata{ 414 Source: M.SocksaddrFromNetIP(source), 415 Destination: M.SocksaddrFromNetIP(destination), 416 } 417 s.udpNat.NewPacket(s.ctx, source, data.ToOwned(), metadata, func(natConn N.PacketConn) N.PacketWriter { 418 headerLen := packet.HeaderLen() + clashtcpip.UDPHeaderSize 419 headerCopy := make([]byte, headerLen) 420 copy(headerCopy, packet[:headerLen]) 421 return &systemUDPPacketWriter4{ 422 s.tun, 423 s.frontHeadroom + PacketOffset, 424 headerCopy, 425 source, 426 s.txChecksumOffload, 427 } 428 }) 429 return nil 430 } 431 432 func (s *System) processIPv6UDP(packet clashtcpip.IPv6Packet, header clashtcpip.UDPPacket) error { 433 if !header.Valid() { 434 return E.New("ipv6: udp: invalid packet") 435 } 436 source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort()) 437 destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort()) 438 if !destination.Addr().IsGlobalUnicast() { 439 return nil 440 } 441 data := buf.As(header.Payload()) 442 if data.Len() == 0 { 443 return nil 444 } 445 metadata := M.Metadata{ 446 Source: M.SocksaddrFromNetIP(source), 447 Destination: M.SocksaddrFromNetIP(destination), 448 } 449 s.udpNat.NewPacket(s.ctx, source, data.ToOwned(), metadata, func(natConn N.PacketConn) N.PacketWriter { 450 headerLen := len(packet) - int(header.Length()) + clashtcpip.UDPHeaderSize 451 headerCopy := make([]byte, headerLen) 452 copy(headerCopy, packet[:headerLen]) 453 return &systemUDPPacketWriter6{ 454 s.tun, 455 s.frontHeadroom + PacketOffset, 456 headerCopy, 457 source, 458 s.txChecksumOffload, 459 } 460 }) 461 return nil 462 } 463 464 func (s *System) processIPv4ICMP(packet clashtcpip.IPv4Packet, header clashtcpip.ICMPPacket) error { 465 if header.Type() != clashtcpip.ICMPTypePingRequest || header.Code() != 0 { 466 return nil 467 } 468 header.SetType(clashtcpip.ICMPTypePingResponse) 469 sourceAddress := packet.SourceIP() 470 packet.SetSourceIP(packet.DestinationIP()) 471 packet.SetDestinationIP(sourceAddress) 472 header.ResetChecksum() 473 packet.ResetChecksum() 474 return nil 475 } 476 477 func (s *System) processIPv6ICMP(packet clashtcpip.IPv6Packet, header clashtcpip.ICMPv6Packet) error { 478 if header.Type() != clashtcpip.ICMPv6EchoRequest || header.Code() != 0 { 479 return nil 480 } 481 header.SetType(clashtcpip.ICMPv6EchoReply) 482 sourceAddress := packet.SourceIP() 483 packet.SetSourceIP(packet.DestinationIP()) 484 packet.SetDestinationIP(sourceAddress) 485 header.ResetChecksum(packet.PseudoSum()) 486 packet.ResetChecksum() 487 return nil 488 } 489 490 type systemUDPPacketWriter4 struct { 491 tun Tun 492 frontHeadroom int 493 header []byte 494 source netip.AddrPort 495 txChecksumOffload bool 496 } 497 498 func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { 499 newPacket := buf.NewSize(w.frontHeadroom + len(w.header) + buffer.Len()) 500 defer newPacket.Release() 501 newPacket.Resize(w.frontHeadroom, 0) 502 newPacket.Write(w.header) 503 newPacket.Write(buffer.Bytes()) 504 ipHdr := clashtcpip.IPv4Packet(newPacket.Bytes()) 505 ipHdr.SetTotalLength(uint16(newPacket.Len())) 506 ipHdr.SetDestinationIP(ipHdr.SourceIP()) 507 ipHdr.SetSourceIP(destination.Addr) 508 udpHdr := clashtcpip.UDPPacket(ipHdr.Payload()) 509 udpHdr.SetDestinationPort(udpHdr.SourcePort()) 510 udpHdr.SetSourcePort(destination.Port) 511 udpHdr.SetLength(uint16(buffer.Len() + clashtcpip.UDPHeaderSize)) 512 if !w.txChecksumOffload { 513 udpHdr.ResetChecksum(ipHdr.PseudoSum()) 514 ipHdr.ResetChecksum() 515 } else { 516 udpHdr.OffloadChecksum() 517 ipHdr.ResetChecksum() 518 } 519 if PacketOffset > 0 { 520 newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET 521 } else { 522 newPacket.Advance(-w.frontHeadroom) 523 } 524 return common.Error(w.tun.Write(newPacket.Bytes())) 525 } 526 527 type systemUDPPacketWriter6 struct { 528 tun Tun 529 frontHeadroom int 530 header []byte 531 source netip.AddrPort 532 txChecksumOffload bool 533 } 534 535 func (w *systemUDPPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { 536 newPacket := buf.NewSize(w.frontHeadroom + len(w.header) + buffer.Len()) 537 defer newPacket.Release() 538 newPacket.Resize(w.frontHeadroom, 0) 539 newPacket.Write(w.header) 540 newPacket.Write(buffer.Bytes()) 541 ipHdr := clashtcpip.IPv6Packet(newPacket.Bytes()) 542 udpLen := uint16(clashtcpip.UDPHeaderSize + buffer.Len()) 543 ipHdr.SetPayloadLength(udpLen) 544 ipHdr.SetDestinationIP(ipHdr.SourceIP()) 545 ipHdr.SetSourceIP(destination.Addr) 546 udpHdr := clashtcpip.UDPPacket(ipHdr.Payload()) 547 udpHdr.SetDestinationPort(udpHdr.SourcePort()) 548 udpHdr.SetSourcePort(destination.Port) 549 udpHdr.SetLength(udpLen) 550 if !w.txChecksumOffload { 551 udpHdr.ResetChecksum(ipHdr.PseudoSum()) 552 } else { 553 udpHdr.OffloadChecksum() 554 } 555 if PacketOffset > 0 { 556 newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET6 557 } else { 558 newPacket.Advance(-w.frontHeadroom) 559 } 560 return common.Error(w.tun.Write(newPacket.Bytes())) 561 }