github.com/sagernet/sing-tun@v0.3.0-beta.5/stack_system.go (about) 1 package tun 2 3 import ( 4 "context" 5 "net" 6 "net/netip" 7 "syscall" 8 "time" 9 10 "github.com/sagernet/sing-tun/internal/clashtcpip" 11 "github.com/sagernet/sing/common" 12 "github.com/sagernet/sing/common/buf" 13 "github.com/sagernet/sing/common/control" 14 E "github.com/sagernet/sing/common/exceptions" 15 "github.com/sagernet/sing/common/logger" 16 M "github.com/sagernet/sing/common/metadata" 17 N "github.com/sagernet/sing/common/network" 18 "github.com/sagernet/sing/common/udpnat" 19 ) 20 21 var ErrIncludeAllNetworks = E.New("`system` and `mixed` stack are not available when `includeAllNetworks` is enabled. See https://github.com/SagerNet/sing-tun/issues/25") 22 23 type System struct { 24 ctx context.Context 25 tun Tun 26 tunName string 27 mtu int 28 handler Handler 29 logger logger.Logger 30 inet4Prefixes []netip.Prefix 31 inet6Prefixes []netip.Prefix 32 inet4ServerAddress netip.Addr 33 inet4Address netip.Addr 34 inet6ServerAddress netip.Addr 35 inet6Address netip.Addr 36 broadcastAddr netip.Addr 37 udpTimeout int64 38 tcpListener net.Listener 39 tcpListener6 net.Listener 40 tcpPort uint16 41 tcpPort6 uint16 42 tcpNat *TCPNat 43 udpNat *udpnat.Service[netip.AddrPort] 44 bindInterface bool 45 interfaceFinder control.InterfaceFinder 46 frontHeadroom int 47 txChecksumOffload bool 48 } 49 50 type Session struct { 51 SourceAddress netip.Addr 52 DestinationAddress netip.Addr 53 SourcePort uint16 54 DestinationPort uint16 55 } 56 57 func NewSystem(options StackOptions) (Stack, error) { 58 stack := &System{ 59 ctx: options.Context, 60 tun: options.Tun, 61 tunName: options.TunOptions.Name, 62 mtu: int(options.TunOptions.MTU), 63 udpTimeout: options.UDPTimeout, 64 handler: options.Handler, 65 logger: options.Logger, 66 inet4Prefixes: options.TunOptions.Inet4Address, 67 inet6Prefixes: options.TunOptions.Inet6Address, 68 broadcastAddr: BroadcastAddr(options.TunOptions.Inet4Address), 69 bindInterface: options.ForwarderBindInterface, 70 interfaceFinder: options.InterfaceFinder, 71 } 72 if len(options.TunOptions.Inet4Address) > 0 { 73 if options.TunOptions.Inet4Address[0].Bits() == 32 { 74 return nil, E.New("need one more IPv4 address in first prefix for system stack") 75 } 76 stack.inet4ServerAddress = options.TunOptions.Inet4Address[0].Addr() 77 stack.inet4Address = stack.inet4ServerAddress.Next() 78 } 79 if len(options.TunOptions.Inet6Address) > 0 { 80 if options.TunOptions.Inet6Address[0].Bits() == 128 { 81 return nil, E.New("need one more IPv6 address in first prefix for system stack") 82 } 83 stack.inet6ServerAddress = options.TunOptions.Inet6Address[0].Addr() 84 stack.inet6Address = stack.inet6ServerAddress.Next() 85 } 86 if !stack.inet4Address.IsValid() && !stack.inet6Address.IsValid() { 87 return nil, E.New("missing interface address") 88 } 89 return stack, nil 90 } 91 92 func (s *System) Close() error { 93 return common.Close( 94 s.tcpListener, 95 s.tcpListener6, 96 ) 97 } 98 99 func (s *System) Start() error { 100 err := s.start() 101 if err != nil { 102 return err 103 } 104 go s.tunLoop() 105 return nil 106 } 107 108 func (s *System) start() error { 109 err := fixWindowsFirewall() 110 if err != nil { 111 return E.Cause(err, "fix windows firewall for system stack") 112 } 113 var listener net.ListenConfig 114 if s.bindInterface { 115 listener.Control = control.Append(listener.Control, func(network, address string, conn syscall.RawConn) error { 116 bindErr := control.BindToInterface0(s.interfaceFinder, conn, network, address, s.tunName, -1, true) 117 if bindErr != nil { 118 s.logger.Warn("bind forwarder to interface: ", bindErr) 119 } 120 return nil 121 }) 122 } 123 if s.inet4Address.IsValid() { 124 tcpListener, err := listener.Listen(s.ctx, "tcp4", net.JoinHostPort(s.inet4ServerAddress.String(), "0")) 125 if err != nil { 126 return err 127 } 128 s.tcpListener = tcpListener 129 s.tcpPort = M.SocksaddrFromNet(tcpListener.Addr()).Port 130 go s.acceptLoop(tcpListener) 131 } 132 if s.inet6Address.IsValid() { 133 tcpListener, err := listener.Listen(s.ctx, "tcp6", net.JoinHostPort(s.inet6ServerAddress.String(), "0")) 134 if err != nil { 135 return err 136 } 137 s.tcpListener6 = tcpListener 138 s.tcpPort6 = M.SocksaddrFromNet(tcpListener.Addr()).Port 139 go s.acceptLoop(tcpListener) 140 } 141 s.tcpNat = NewNat(s.ctx, time.Second*time.Duration(s.udpTimeout)) 142 s.udpNat = udpnat.New[netip.AddrPort](s.udpTimeout, s.handler) 143 return nil 144 } 145 146 func (s *System) tunLoop() { 147 if winTun, isWinTun := s.tun.(WinTun); isWinTun { 148 s.wintunLoop(winTun) 149 return 150 } 151 if linuxTUN, isLinuxTUN := s.tun.(LinuxTUN); isLinuxTUN { 152 s.frontHeadroom = linuxTUN.FrontHeadroom() 153 s.txChecksumOffload = linuxTUN.TXChecksumOffload() 154 batchSize := linuxTUN.BatchSize() 155 if batchSize > 1 { 156 s.batchLoop(linuxTUN, batchSize) 157 return 158 } 159 } 160 packetBuffer := make([]byte, s.mtu+PacketOffset) 161 for { 162 n, err := s.tun.Read(packetBuffer) 163 if err != nil { 164 if E.IsClosed(err) { 165 return 166 } 167 s.logger.Error(E.Cause(err, "read packet")) 168 } 169 if n < clashtcpip.IPv4PacketMinLength { 170 continue 171 } 172 rawPacket := packetBuffer[:n] 173 packet := packetBuffer[PacketOffset:n] 174 if s.processPacket(packet) { 175 _, err = s.tun.Write(rawPacket) 176 if err != nil { 177 s.logger.Trace(E.Cause(err, "write packet")) 178 } 179 } 180 } 181 } 182 183 func (s *System) wintunLoop(winTun WinTun) { 184 for { 185 packet, release, err := winTun.ReadPacket() 186 if err != nil { 187 return 188 } 189 if len(packet) < clashtcpip.IPv4PacketMinLength { 190 release() 191 continue 192 } 193 if s.processPacket(packet) { 194 _, err = winTun.Write(packet) 195 if err != nil { 196 s.logger.Trace(E.Cause(err, "write packet")) 197 } 198 } 199 release() 200 } 201 } 202 203 func (s *System) batchLoop(linuxTUN LinuxTUN, batchSize int) { 204 packetBuffers := make([][]byte, batchSize) 205 writeBuffers := make([][]byte, batchSize) 206 packetSizes := make([]int, batchSize) 207 for i := range packetBuffers { 208 packetBuffers[i] = make([]byte, s.mtu+s.frontHeadroom) 209 } 210 for { 211 n, err := linuxTUN.BatchRead(packetBuffers, s.frontHeadroom, packetSizes) 212 if err != nil { 213 if E.IsClosed(err) { 214 return 215 } 216 s.logger.Error(E.Cause(err, "batch read packet")) 217 } 218 if n == 0 { 219 continue 220 } 221 for i := 0; i < n; i++ { 222 packetSize := packetSizes[i] 223 if packetSize < clashtcpip.IPv4PacketMinLength { 224 continue 225 } 226 packetBuffer := packetBuffers[i] 227 packet := packetBuffer[s.frontHeadroom : s.frontHeadroom+packetSize] 228 if s.processPacket(packet) { 229 writeBuffers = append(writeBuffers, packetBuffer[:s.frontHeadroom+packetSize]) 230 } 231 } 232 if len(writeBuffers) > 0 { 233 err = linuxTUN.BatchWrite(writeBuffers, s.frontHeadroom) 234 if err != nil { 235 s.logger.Trace(E.Cause(err, "batch write packet")) 236 } 237 writeBuffers = writeBuffers[:0] 238 } 239 } 240 } 241 242 func (s *System) processPacket(packet []byte) bool { 243 var ( 244 writeBack bool 245 err error 246 ) 247 switch ipVersion := packet[0] >> 4; ipVersion { 248 case 4: 249 writeBack, err = s.processIPv4(packet) 250 case 6: 251 writeBack, err = s.processIPv6(packet) 252 default: 253 err = E.New("ip: unknown version: ", ipVersion) 254 } 255 if err != nil { 256 s.logger.Trace(err) 257 return false 258 } 259 return writeBack 260 } 261 262 func (s *System) acceptLoop(listener net.Listener) { 263 for { 264 conn, err := listener.Accept() 265 if err != nil { 266 return 267 } 268 connPort := M.SocksaddrFromNet(conn.RemoteAddr()).Port 269 session := s.tcpNat.LookupBack(connPort) 270 if session == nil { 271 s.logger.Trace(E.New("unknown session with port ", connPort)) 272 continue 273 } 274 destination := M.SocksaddrFromNetIP(session.Destination) 275 if destination.Addr.Is4() { 276 for _, prefix := range s.inet4Prefixes { 277 if prefix.Contains(destination.Addr) { 278 destination.Addr = netip.AddrFrom4([4]byte{127, 0, 0, 1}) 279 break 280 } 281 } 282 } else { 283 for _, prefix := range s.inet6Prefixes { 284 if prefix.Contains(destination.Addr) { 285 destination.Addr = netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}) 286 break 287 } 288 } 289 } 290 go func() { 291 _ = s.handler.NewConnection(s.ctx, conn, M.Metadata{ 292 Source: M.SocksaddrFromNetIP(session.Source), 293 Destination: destination, 294 }) 295 if tcpConn, isTCPConn := conn.(*net.TCPConn); isTCPConn { 296 _ = tcpConn.SetLinger(0) 297 } 298 _ = conn.Close() 299 }() 300 } 301 } 302 303 func (s *System) processIPv4(packet clashtcpip.IPv4Packet) (writeBack bool, err error) { 304 writeBack = true 305 destination := packet.DestinationIP() 306 if destination == s.broadcastAddr || !destination.IsGlobalUnicast() { 307 return 308 } 309 switch packet.Protocol() { 310 case clashtcpip.TCP: 311 err = s.processIPv4TCP(packet, packet.Payload()) 312 case clashtcpip.UDP: 313 writeBack = false 314 err = s.processIPv4UDP(packet, packet.Payload()) 315 case clashtcpip.ICMP: 316 err = s.processIPv4ICMP(packet, packet.Payload()) 317 } 318 return 319 } 320 321 func (s *System) processIPv6(packet clashtcpip.IPv6Packet) (writeBack bool, err error) { 322 writeBack = true 323 if !packet.DestinationIP().IsGlobalUnicast() { 324 return 325 } 326 switch packet.Protocol() { 327 case clashtcpip.TCP: 328 err = s.processIPv6TCP(packet, packet.Payload()) 329 case clashtcpip.UDP: 330 writeBack = false 331 err = s.processIPv6UDP(packet, packet.Payload()) 332 case clashtcpip.ICMPv6: 333 err = s.processIPv6ICMP(packet, packet.Payload()) 334 } 335 return 336 } 337 338 func (s *System) processIPv4TCP(packet clashtcpip.IPv4Packet, header clashtcpip.TCPPacket) error { 339 source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort()) 340 destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort()) 341 if !destination.Addr().IsGlobalUnicast() { 342 return nil 343 } else if source.Addr() == s.inet4ServerAddress && source.Port() == s.tcpPort { 344 session := s.tcpNat.LookupBack(destination.Port()) 345 if session == nil { 346 return E.New("ipv4: tcp: session not found: ", destination.Port()) 347 } 348 packet.SetSourceIP(session.Destination.Addr()) 349 header.SetSourcePort(session.Destination.Port()) 350 packet.SetDestinationIP(session.Source.Addr()) 351 header.SetDestinationPort(session.Source.Port()) 352 } else { 353 natPort := s.tcpNat.Lookup(source, destination) 354 packet.SetSourceIP(s.inet4Address) 355 header.SetSourcePort(natPort) 356 packet.SetDestinationIP(s.inet4ServerAddress) 357 header.SetDestinationPort(s.tcpPort) 358 } 359 if !s.txChecksumOffload { 360 header.ResetChecksum(packet.PseudoSum()) 361 packet.ResetChecksum() 362 } else { 363 header.OffloadChecksum() 364 packet.ResetChecksum() 365 } 366 return nil 367 } 368 369 func (s *System) processIPv6TCP(packet clashtcpip.IPv6Packet, header clashtcpip.TCPPacket) error { 370 source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort()) 371 destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort()) 372 if !destination.Addr().IsGlobalUnicast() { 373 return nil 374 } else if source.Addr() == s.inet6ServerAddress && source.Port() == s.tcpPort6 { 375 session := s.tcpNat.LookupBack(destination.Port()) 376 if session == nil { 377 return E.New("ipv6: tcp: session not found: ", destination.Port()) 378 } 379 packet.SetSourceIP(session.Destination.Addr()) 380 header.SetSourcePort(session.Destination.Port()) 381 packet.SetDestinationIP(session.Source.Addr()) 382 header.SetDestinationPort(session.Source.Port()) 383 } else { 384 natPort := s.tcpNat.Lookup(source, destination) 385 packet.SetSourceIP(s.inet6Address) 386 header.SetSourcePort(natPort) 387 packet.SetDestinationIP(s.inet6ServerAddress) 388 header.SetDestinationPort(s.tcpPort6) 389 } 390 if !s.txChecksumOffload { 391 header.ResetChecksum(packet.PseudoSum()) 392 } else { 393 header.OffloadChecksum() 394 } 395 return nil 396 } 397 398 func (s *System) processIPv4UDP(packet clashtcpip.IPv4Packet, header clashtcpip.UDPPacket) error { 399 if packet.Flags()&clashtcpip.FlagMoreFragment != 0 { 400 return E.New("ipv4: fragment dropped") 401 } 402 if packet.FragmentOffset() != 0 { 403 return E.New("ipv4: udp: fragment dropped") 404 } 405 if !header.Valid() { 406 return E.New("ipv4: udp: invalid packet") 407 } 408 source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort()) 409 destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort()) 410 if !destination.Addr().IsGlobalUnicast() { 411 return nil 412 } 413 data := buf.As(header.Payload()) 414 if data.Len() == 0 { 415 return nil 416 } 417 metadata := M.Metadata{ 418 Source: M.SocksaddrFromNetIP(source), 419 Destination: M.SocksaddrFromNetIP(destination), 420 } 421 s.udpNat.NewPacket(s.ctx, source, data.ToOwned(), metadata, func(natConn N.PacketConn) N.PacketWriter { 422 headerLen := packet.HeaderLen() + clashtcpip.UDPHeaderSize 423 headerCopy := make([]byte, headerLen) 424 copy(headerCopy, packet[:headerLen]) 425 return &systemUDPPacketWriter4{ 426 s.tun, 427 s.frontHeadroom + PacketOffset, 428 headerCopy, 429 source, 430 s.txChecksumOffload, 431 } 432 }) 433 return nil 434 } 435 436 func (s *System) processIPv6UDP(packet clashtcpip.IPv6Packet, header clashtcpip.UDPPacket) error { 437 if !header.Valid() { 438 return E.New("ipv6: udp: invalid packet") 439 } 440 source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort()) 441 destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort()) 442 if !destination.Addr().IsGlobalUnicast() { 443 return nil 444 } 445 data := buf.As(header.Payload()) 446 if data.Len() == 0 { 447 return nil 448 } 449 metadata := M.Metadata{ 450 Source: M.SocksaddrFromNetIP(source), 451 Destination: M.SocksaddrFromNetIP(destination), 452 } 453 s.udpNat.NewPacket(s.ctx, source, data.ToOwned(), metadata, func(natConn N.PacketConn) N.PacketWriter { 454 headerLen := len(packet) - int(header.Length()) + clashtcpip.UDPHeaderSize 455 headerCopy := make([]byte, headerLen) 456 copy(headerCopy, packet[:headerLen]) 457 return &systemUDPPacketWriter6{ 458 s.tun, 459 s.frontHeadroom + PacketOffset, 460 headerCopy, 461 source, 462 s.txChecksumOffload, 463 } 464 }) 465 return nil 466 } 467 468 func (s *System) processIPv4ICMP(packet clashtcpip.IPv4Packet, header clashtcpip.ICMPPacket) error { 469 if header.Type() != clashtcpip.ICMPTypePingRequest || header.Code() != 0 { 470 return nil 471 } 472 header.SetType(clashtcpip.ICMPTypePingResponse) 473 sourceAddress := packet.SourceIP() 474 packet.SetSourceIP(packet.DestinationIP()) 475 packet.SetDestinationIP(sourceAddress) 476 header.ResetChecksum() 477 packet.ResetChecksum() 478 return nil 479 } 480 481 func (s *System) processIPv6ICMP(packet clashtcpip.IPv6Packet, header clashtcpip.ICMPv6Packet) error { 482 if header.Type() != clashtcpip.ICMPv6EchoRequest || header.Code() != 0 { 483 return nil 484 } 485 header.SetType(clashtcpip.ICMPv6EchoReply) 486 sourceAddress := packet.SourceIP() 487 packet.SetSourceIP(packet.DestinationIP()) 488 packet.SetDestinationIP(sourceAddress) 489 header.ResetChecksum(packet.PseudoSum()) 490 packet.ResetChecksum() 491 return nil 492 } 493 494 type systemUDPPacketWriter4 struct { 495 tun Tun 496 frontHeadroom int 497 header []byte 498 source netip.AddrPort 499 txChecksumOffload bool 500 } 501 502 func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { 503 newPacket := buf.NewSize(w.frontHeadroom + len(w.header) + buffer.Len()) 504 defer newPacket.Release() 505 newPacket.Resize(w.frontHeadroom, 0) 506 newPacket.Write(w.header) 507 newPacket.Write(buffer.Bytes()) 508 ipHdr := clashtcpip.IPv4Packet(newPacket.Bytes()) 509 ipHdr.SetTotalLength(uint16(newPacket.Len())) 510 ipHdr.SetDestinationIP(ipHdr.SourceIP()) 511 ipHdr.SetSourceIP(destination.Addr) 512 udpHdr := clashtcpip.UDPPacket(ipHdr.Payload()) 513 udpHdr.SetDestinationPort(udpHdr.SourcePort()) 514 udpHdr.SetSourcePort(destination.Port) 515 udpHdr.SetLength(uint16(buffer.Len() + clashtcpip.UDPHeaderSize)) 516 if !w.txChecksumOffload { 517 udpHdr.ResetChecksum(ipHdr.PseudoSum()) 518 ipHdr.ResetChecksum() 519 } else { 520 udpHdr.OffloadChecksum() 521 ipHdr.ResetChecksum() 522 } 523 if PacketOffset > 0 { 524 newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET 525 } else { 526 newPacket.Advance(-w.frontHeadroom) 527 } 528 return common.Error(w.tun.Write(newPacket.Bytes())) 529 } 530 531 type systemUDPPacketWriter6 struct { 532 tun Tun 533 frontHeadroom int 534 header []byte 535 source netip.AddrPort 536 txChecksumOffload bool 537 } 538 539 func (w *systemUDPPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { 540 newPacket := buf.NewSize(w.frontHeadroom + len(w.header) + buffer.Len()) 541 defer newPacket.Release() 542 newPacket.Resize(w.frontHeadroom, 0) 543 newPacket.Write(w.header) 544 newPacket.Write(buffer.Bytes()) 545 ipHdr := clashtcpip.IPv6Packet(newPacket.Bytes()) 546 udpLen := uint16(clashtcpip.UDPHeaderSize + buffer.Len()) 547 ipHdr.SetPayloadLength(udpLen) 548 ipHdr.SetDestinationIP(ipHdr.SourceIP()) 549 ipHdr.SetSourceIP(destination.Addr) 550 udpHdr := clashtcpip.UDPPacket(ipHdr.Payload()) 551 udpHdr.SetDestinationPort(udpHdr.SourcePort()) 552 udpHdr.SetSourcePort(destination.Port) 553 udpHdr.SetLength(udpLen) 554 if !w.txChecksumOffload { 555 udpHdr.ResetChecksum(ipHdr.PseudoSum()) 556 } else { 557 udpHdr.OffloadChecksum() 558 } 559 if PacketOffset > 0 { 560 newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET6 561 } else { 562 newPacket.Advance(-w.frontHeadroom) 563 } 564 return common.Error(w.tun.Write(newPacket.Bytes())) 565 }