github.com/sagernet/sing-tun@v0.3.0-beta.5/stack_system.go (about)

     1  package tun
     2  
     3  import (
     4  	"context"
     5  	"net"
     6  	"net/netip"
     7  	"syscall"
     8  	"time"
     9  
    10  	"github.com/sagernet/sing-tun/internal/clashtcpip"
    11  	"github.com/sagernet/sing/common"
    12  	"github.com/sagernet/sing/common/buf"
    13  	"github.com/sagernet/sing/common/control"
    14  	E "github.com/sagernet/sing/common/exceptions"
    15  	"github.com/sagernet/sing/common/logger"
    16  	M "github.com/sagernet/sing/common/metadata"
    17  	N "github.com/sagernet/sing/common/network"
    18  	"github.com/sagernet/sing/common/udpnat"
    19  )
    20  
    21  var ErrIncludeAllNetworks = E.New("`system` and `mixed` stack are not available when `includeAllNetworks` is enabled. See https://github.com/SagerNet/sing-tun/issues/25")
    22  
    23  type System struct {
    24  	ctx                context.Context
    25  	tun                Tun
    26  	tunName            string
    27  	mtu                int
    28  	handler            Handler
    29  	logger             logger.Logger
    30  	inet4Prefixes      []netip.Prefix
    31  	inet6Prefixes      []netip.Prefix
    32  	inet4ServerAddress netip.Addr
    33  	inet4Address       netip.Addr
    34  	inet6ServerAddress netip.Addr
    35  	inet6Address       netip.Addr
    36  	broadcastAddr      netip.Addr
    37  	udpTimeout         int64
    38  	tcpListener        net.Listener
    39  	tcpListener6       net.Listener
    40  	tcpPort            uint16
    41  	tcpPort6           uint16
    42  	tcpNat             *TCPNat
    43  	udpNat             *udpnat.Service[netip.AddrPort]
    44  	bindInterface      bool
    45  	interfaceFinder    control.InterfaceFinder
    46  	frontHeadroom      int
    47  	txChecksumOffload  bool
    48  }
    49  
    50  type Session struct {
    51  	SourceAddress      netip.Addr
    52  	DestinationAddress netip.Addr
    53  	SourcePort         uint16
    54  	DestinationPort    uint16
    55  }
    56  
    57  func NewSystem(options StackOptions) (Stack, error) {
    58  	stack := &System{
    59  		ctx:             options.Context,
    60  		tun:             options.Tun,
    61  		tunName:         options.TunOptions.Name,
    62  		mtu:             int(options.TunOptions.MTU),
    63  		udpTimeout:      options.UDPTimeout,
    64  		handler:         options.Handler,
    65  		logger:          options.Logger,
    66  		inet4Prefixes:   options.TunOptions.Inet4Address,
    67  		inet6Prefixes:   options.TunOptions.Inet6Address,
    68  		broadcastAddr:   BroadcastAddr(options.TunOptions.Inet4Address),
    69  		bindInterface:   options.ForwarderBindInterface,
    70  		interfaceFinder: options.InterfaceFinder,
    71  	}
    72  	if len(options.TunOptions.Inet4Address) > 0 {
    73  		if options.TunOptions.Inet4Address[0].Bits() == 32 {
    74  			return nil, E.New("need one more IPv4 address in first prefix for system stack")
    75  		}
    76  		stack.inet4ServerAddress = options.TunOptions.Inet4Address[0].Addr()
    77  		stack.inet4Address = stack.inet4ServerAddress.Next()
    78  	}
    79  	if len(options.TunOptions.Inet6Address) > 0 {
    80  		if options.TunOptions.Inet6Address[0].Bits() == 128 {
    81  			return nil, E.New("need one more IPv6 address in first prefix for system stack")
    82  		}
    83  		stack.inet6ServerAddress = options.TunOptions.Inet6Address[0].Addr()
    84  		stack.inet6Address = stack.inet6ServerAddress.Next()
    85  	}
    86  	if !stack.inet4Address.IsValid() && !stack.inet6Address.IsValid() {
    87  		return nil, E.New("missing interface address")
    88  	}
    89  	return stack, nil
    90  }
    91  
    92  func (s *System) Close() error {
    93  	return common.Close(
    94  		s.tcpListener,
    95  		s.tcpListener6,
    96  	)
    97  }
    98  
    99  func (s *System) Start() error {
   100  	err := s.start()
   101  	if err != nil {
   102  		return err
   103  	}
   104  	go s.tunLoop()
   105  	return nil
   106  }
   107  
   108  func (s *System) start() error {
   109  	err := fixWindowsFirewall()
   110  	if err != nil {
   111  		return E.Cause(err, "fix windows firewall for system stack")
   112  	}
   113  	var listener net.ListenConfig
   114  	if s.bindInterface {
   115  		listener.Control = control.Append(listener.Control, func(network, address string, conn syscall.RawConn) error {
   116  			bindErr := control.BindToInterface0(s.interfaceFinder, conn, network, address, s.tunName, -1, true)
   117  			if bindErr != nil {
   118  				s.logger.Warn("bind forwarder to interface: ", bindErr)
   119  			}
   120  			return nil
   121  		})
   122  	}
   123  	if s.inet4Address.IsValid() {
   124  		tcpListener, err := listener.Listen(s.ctx, "tcp4", net.JoinHostPort(s.inet4ServerAddress.String(), "0"))
   125  		if err != nil {
   126  			return err
   127  		}
   128  		s.tcpListener = tcpListener
   129  		s.tcpPort = M.SocksaddrFromNet(tcpListener.Addr()).Port
   130  		go s.acceptLoop(tcpListener)
   131  	}
   132  	if s.inet6Address.IsValid() {
   133  		tcpListener, err := listener.Listen(s.ctx, "tcp6", net.JoinHostPort(s.inet6ServerAddress.String(), "0"))
   134  		if err != nil {
   135  			return err
   136  		}
   137  		s.tcpListener6 = tcpListener
   138  		s.tcpPort6 = M.SocksaddrFromNet(tcpListener.Addr()).Port
   139  		go s.acceptLoop(tcpListener)
   140  	}
   141  	s.tcpNat = NewNat(s.ctx, time.Second*time.Duration(s.udpTimeout))
   142  	s.udpNat = udpnat.New[netip.AddrPort](s.udpTimeout, s.handler)
   143  	return nil
   144  }
   145  
   146  func (s *System) tunLoop() {
   147  	if winTun, isWinTun := s.tun.(WinTun); isWinTun {
   148  		s.wintunLoop(winTun)
   149  		return
   150  	}
   151  	if linuxTUN, isLinuxTUN := s.tun.(LinuxTUN); isLinuxTUN {
   152  		s.frontHeadroom = linuxTUN.FrontHeadroom()
   153  		s.txChecksumOffload = linuxTUN.TXChecksumOffload()
   154  		batchSize := linuxTUN.BatchSize()
   155  		if batchSize > 1 {
   156  			s.batchLoop(linuxTUN, batchSize)
   157  			return
   158  		}
   159  	}
   160  	packetBuffer := make([]byte, s.mtu+PacketOffset)
   161  	for {
   162  		n, err := s.tun.Read(packetBuffer)
   163  		if err != nil {
   164  			if E.IsClosed(err) {
   165  				return
   166  			}
   167  			s.logger.Error(E.Cause(err, "read packet"))
   168  		}
   169  		if n < clashtcpip.IPv4PacketMinLength {
   170  			continue
   171  		}
   172  		rawPacket := packetBuffer[:n]
   173  		packet := packetBuffer[PacketOffset:n]
   174  		if s.processPacket(packet) {
   175  			_, err = s.tun.Write(rawPacket)
   176  			if err != nil {
   177  				s.logger.Trace(E.Cause(err, "write packet"))
   178  			}
   179  		}
   180  	}
   181  }
   182  
   183  func (s *System) wintunLoop(winTun WinTun) {
   184  	for {
   185  		packet, release, err := winTun.ReadPacket()
   186  		if err != nil {
   187  			return
   188  		}
   189  		if len(packet) < clashtcpip.IPv4PacketMinLength {
   190  			release()
   191  			continue
   192  		}
   193  		if s.processPacket(packet) {
   194  			_, err = winTun.Write(packet)
   195  			if err != nil {
   196  				s.logger.Trace(E.Cause(err, "write packet"))
   197  			}
   198  		}
   199  		release()
   200  	}
   201  }
   202  
   203  func (s *System) batchLoop(linuxTUN LinuxTUN, batchSize int) {
   204  	packetBuffers := make([][]byte, batchSize)
   205  	writeBuffers := make([][]byte, batchSize)
   206  	packetSizes := make([]int, batchSize)
   207  	for i := range packetBuffers {
   208  		packetBuffers[i] = make([]byte, s.mtu+s.frontHeadroom)
   209  	}
   210  	for {
   211  		n, err := linuxTUN.BatchRead(packetBuffers, s.frontHeadroom, packetSizes)
   212  		if err != nil {
   213  			if E.IsClosed(err) {
   214  				return
   215  			}
   216  			s.logger.Error(E.Cause(err, "batch read packet"))
   217  		}
   218  		if n == 0 {
   219  			continue
   220  		}
   221  		for i := 0; i < n; i++ {
   222  			packetSize := packetSizes[i]
   223  			if packetSize < clashtcpip.IPv4PacketMinLength {
   224  				continue
   225  			}
   226  			packetBuffer := packetBuffers[i]
   227  			packet := packetBuffer[s.frontHeadroom : s.frontHeadroom+packetSize]
   228  			if s.processPacket(packet) {
   229  				writeBuffers = append(writeBuffers, packetBuffer[:s.frontHeadroom+packetSize])
   230  			}
   231  		}
   232  		if len(writeBuffers) > 0 {
   233  			err = linuxTUN.BatchWrite(writeBuffers, s.frontHeadroom)
   234  			if err != nil {
   235  				s.logger.Trace(E.Cause(err, "batch write packet"))
   236  			}
   237  			writeBuffers = writeBuffers[:0]
   238  		}
   239  	}
   240  }
   241  
   242  func (s *System) processPacket(packet []byte) bool {
   243  	var (
   244  		writeBack bool
   245  		err       error
   246  	)
   247  	switch ipVersion := packet[0] >> 4; ipVersion {
   248  	case 4:
   249  		writeBack, err = s.processIPv4(packet)
   250  	case 6:
   251  		writeBack, err = s.processIPv6(packet)
   252  	default:
   253  		err = E.New("ip: unknown version: ", ipVersion)
   254  	}
   255  	if err != nil {
   256  		s.logger.Trace(err)
   257  		return false
   258  	}
   259  	return writeBack
   260  }
   261  
   262  func (s *System) acceptLoop(listener net.Listener) {
   263  	for {
   264  		conn, err := listener.Accept()
   265  		if err != nil {
   266  			return
   267  		}
   268  		connPort := M.SocksaddrFromNet(conn.RemoteAddr()).Port
   269  		session := s.tcpNat.LookupBack(connPort)
   270  		if session == nil {
   271  			s.logger.Trace(E.New("unknown session with port ", connPort))
   272  			continue
   273  		}
   274  		destination := M.SocksaddrFromNetIP(session.Destination)
   275  		if destination.Addr.Is4() {
   276  			for _, prefix := range s.inet4Prefixes {
   277  				if prefix.Contains(destination.Addr) {
   278  					destination.Addr = netip.AddrFrom4([4]byte{127, 0, 0, 1})
   279  					break
   280  				}
   281  			}
   282  		} else {
   283  			for _, prefix := range s.inet6Prefixes {
   284  				if prefix.Contains(destination.Addr) {
   285  					destination.Addr = netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1})
   286  					break
   287  				}
   288  			}
   289  		}
   290  		go func() {
   291  			_ = s.handler.NewConnection(s.ctx, conn, M.Metadata{
   292  				Source:      M.SocksaddrFromNetIP(session.Source),
   293  				Destination: destination,
   294  			})
   295  			if tcpConn, isTCPConn := conn.(*net.TCPConn); isTCPConn {
   296  				_ = tcpConn.SetLinger(0)
   297  			}
   298  			_ = conn.Close()
   299  		}()
   300  	}
   301  }
   302  
   303  func (s *System) processIPv4(packet clashtcpip.IPv4Packet) (writeBack bool, err error) {
   304  	writeBack = true
   305  	destination := packet.DestinationIP()
   306  	if destination == s.broadcastAddr || !destination.IsGlobalUnicast() {
   307  		return
   308  	}
   309  	switch packet.Protocol() {
   310  	case clashtcpip.TCP:
   311  		err = s.processIPv4TCP(packet, packet.Payload())
   312  	case clashtcpip.UDP:
   313  		writeBack = false
   314  		err = s.processIPv4UDP(packet, packet.Payload())
   315  	case clashtcpip.ICMP:
   316  		err = s.processIPv4ICMP(packet, packet.Payload())
   317  	}
   318  	return
   319  }
   320  
   321  func (s *System) processIPv6(packet clashtcpip.IPv6Packet) (writeBack bool, err error) {
   322  	writeBack = true
   323  	if !packet.DestinationIP().IsGlobalUnicast() {
   324  		return
   325  	}
   326  	switch packet.Protocol() {
   327  	case clashtcpip.TCP:
   328  		err = s.processIPv6TCP(packet, packet.Payload())
   329  	case clashtcpip.UDP:
   330  		writeBack = false
   331  		err = s.processIPv6UDP(packet, packet.Payload())
   332  	case clashtcpip.ICMPv6:
   333  		err = s.processIPv6ICMP(packet, packet.Payload())
   334  	}
   335  	return
   336  }
   337  
   338  func (s *System) processIPv4TCP(packet clashtcpip.IPv4Packet, header clashtcpip.TCPPacket) error {
   339  	source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort())
   340  	destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort())
   341  	if !destination.Addr().IsGlobalUnicast() {
   342  		return nil
   343  	} else if source.Addr() == s.inet4ServerAddress && source.Port() == s.tcpPort {
   344  		session := s.tcpNat.LookupBack(destination.Port())
   345  		if session == nil {
   346  			return E.New("ipv4: tcp: session not found: ", destination.Port())
   347  		}
   348  		packet.SetSourceIP(session.Destination.Addr())
   349  		header.SetSourcePort(session.Destination.Port())
   350  		packet.SetDestinationIP(session.Source.Addr())
   351  		header.SetDestinationPort(session.Source.Port())
   352  	} else {
   353  		natPort := s.tcpNat.Lookup(source, destination)
   354  		packet.SetSourceIP(s.inet4Address)
   355  		header.SetSourcePort(natPort)
   356  		packet.SetDestinationIP(s.inet4ServerAddress)
   357  		header.SetDestinationPort(s.tcpPort)
   358  	}
   359  	if !s.txChecksumOffload {
   360  		header.ResetChecksum(packet.PseudoSum())
   361  		packet.ResetChecksum()
   362  	} else {
   363  		header.OffloadChecksum()
   364  		packet.ResetChecksum()
   365  	}
   366  	return nil
   367  }
   368  
   369  func (s *System) processIPv6TCP(packet clashtcpip.IPv6Packet, header clashtcpip.TCPPacket) error {
   370  	source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort())
   371  	destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort())
   372  	if !destination.Addr().IsGlobalUnicast() {
   373  		return nil
   374  	} else if source.Addr() == s.inet6ServerAddress && source.Port() == s.tcpPort6 {
   375  		session := s.tcpNat.LookupBack(destination.Port())
   376  		if session == nil {
   377  			return E.New("ipv6: tcp: session not found: ", destination.Port())
   378  		}
   379  		packet.SetSourceIP(session.Destination.Addr())
   380  		header.SetSourcePort(session.Destination.Port())
   381  		packet.SetDestinationIP(session.Source.Addr())
   382  		header.SetDestinationPort(session.Source.Port())
   383  	} else {
   384  		natPort := s.tcpNat.Lookup(source, destination)
   385  		packet.SetSourceIP(s.inet6Address)
   386  		header.SetSourcePort(natPort)
   387  		packet.SetDestinationIP(s.inet6ServerAddress)
   388  		header.SetDestinationPort(s.tcpPort6)
   389  	}
   390  	if !s.txChecksumOffload {
   391  		header.ResetChecksum(packet.PseudoSum())
   392  	} else {
   393  		header.OffloadChecksum()
   394  	}
   395  	return nil
   396  }
   397  
   398  func (s *System) processIPv4UDP(packet clashtcpip.IPv4Packet, header clashtcpip.UDPPacket) error {
   399  	if packet.Flags()&clashtcpip.FlagMoreFragment != 0 {
   400  		return E.New("ipv4: fragment dropped")
   401  	}
   402  	if packet.FragmentOffset() != 0 {
   403  		return E.New("ipv4: udp: fragment dropped")
   404  	}
   405  	if !header.Valid() {
   406  		return E.New("ipv4: udp: invalid packet")
   407  	}
   408  	source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort())
   409  	destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort())
   410  	if !destination.Addr().IsGlobalUnicast() {
   411  		return nil
   412  	}
   413  	data := buf.As(header.Payload())
   414  	if data.Len() == 0 {
   415  		return nil
   416  	}
   417  	metadata := M.Metadata{
   418  		Source:      M.SocksaddrFromNetIP(source),
   419  		Destination: M.SocksaddrFromNetIP(destination),
   420  	}
   421  	s.udpNat.NewPacket(s.ctx, source, data.ToOwned(), metadata, func(natConn N.PacketConn) N.PacketWriter {
   422  		headerLen := packet.HeaderLen() + clashtcpip.UDPHeaderSize
   423  		headerCopy := make([]byte, headerLen)
   424  		copy(headerCopy, packet[:headerLen])
   425  		return &systemUDPPacketWriter4{
   426  			s.tun,
   427  			s.frontHeadroom + PacketOffset,
   428  			headerCopy,
   429  			source,
   430  			s.txChecksumOffload,
   431  		}
   432  	})
   433  	return nil
   434  }
   435  
   436  func (s *System) processIPv6UDP(packet clashtcpip.IPv6Packet, header clashtcpip.UDPPacket) error {
   437  	if !header.Valid() {
   438  		return E.New("ipv6: udp: invalid packet")
   439  	}
   440  	source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort())
   441  	destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort())
   442  	if !destination.Addr().IsGlobalUnicast() {
   443  		return nil
   444  	}
   445  	data := buf.As(header.Payload())
   446  	if data.Len() == 0 {
   447  		return nil
   448  	}
   449  	metadata := M.Metadata{
   450  		Source:      M.SocksaddrFromNetIP(source),
   451  		Destination: M.SocksaddrFromNetIP(destination),
   452  	}
   453  	s.udpNat.NewPacket(s.ctx, source, data.ToOwned(), metadata, func(natConn N.PacketConn) N.PacketWriter {
   454  		headerLen := len(packet) - int(header.Length()) + clashtcpip.UDPHeaderSize
   455  		headerCopy := make([]byte, headerLen)
   456  		copy(headerCopy, packet[:headerLen])
   457  		return &systemUDPPacketWriter6{
   458  			s.tun,
   459  			s.frontHeadroom + PacketOffset,
   460  			headerCopy,
   461  			source,
   462  			s.txChecksumOffload,
   463  		}
   464  	})
   465  	return nil
   466  }
   467  
   468  func (s *System) processIPv4ICMP(packet clashtcpip.IPv4Packet, header clashtcpip.ICMPPacket) error {
   469  	if header.Type() != clashtcpip.ICMPTypePingRequest || header.Code() != 0 {
   470  		return nil
   471  	}
   472  	header.SetType(clashtcpip.ICMPTypePingResponse)
   473  	sourceAddress := packet.SourceIP()
   474  	packet.SetSourceIP(packet.DestinationIP())
   475  	packet.SetDestinationIP(sourceAddress)
   476  	header.ResetChecksum()
   477  	packet.ResetChecksum()
   478  	return nil
   479  }
   480  
   481  func (s *System) processIPv6ICMP(packet clashtcpip.IPv6Packet, header clashtcpip.ICMPv6Packet) error {
   482  	if header.Type() != clashtcpip.ICMPv6EchoRequest || header.Code() != 0 {
   483  		return nil
   484  	}
   485  	header.SetType(clashtcpip.ICMPv6EchoReply)
   486  	sourceAddress := packet.SourceIP()
   487  	packet.SetSourceIP(packet.DestinationIP())
   488  	packet.SetDestinationIP(sourceAddress)
   489  	header.ResetChecksum(packet.PseudoSum())
   490  	packet.ResetChecksum()
   491  	return nil
   492  }
   493  
// systemUDPPacketWriter4 writes UDP replies for one IPv4 UDP NAT session back
// to the TUN. Each reply is built from a stored copy of the headers of the
// packet that created the session (presumably its first datagram — confirm
// against udpnat), with addresses and ports swapped.
//
// Field order matters: instances are built with positional composite literals.
type systemUDPPacketWriter4 struct {
	tun Tun
	// Bytes reserved in front of the IP header: the device front headroom
	// plus PacketOffset.
	frontHeadroom int
	// Copy of the request's IPv4 header (HeaderLen bytes, options included)
	// followed by its UDP header (clashtcpip.UDPHeaderSize bytes).
	header []byte
	// Client endpoint of the session; WritePacket does not read it.
	source            netip.AddrPort
	txChecksumOffload bool
}
   501  
   502  func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
   503  	newPacket := buf.NewSize(w.frontHeadroom + len(w.header) + buffer.Len())
   504  	defer newPacket.Release()
   505  	newPacket.Resize(w.frontHeadroom, 0)
   506  	newPacket.Write(w.header)
   507  	newPacket.Write(buffer.Bytes())
   508  	ipHdr := clashtcpip.IPv4Packet(newPacket.Bytes())
   509  	ipHdr.SetTotalLength(uint16(newPacket.Len()))
   510  	ipHdr.SetDestinationIP(ipHdr.SourceIP())
   511  	ipHdr.SetSourceIP(destination.Addr)
   512  	udpHdr := clashtcpip.UDPPacket(ipHdr.Payload())
   513  	udpHdr.SetDestinationPort(udpHdr.SourcePort())
   514  	udpHdr.SetSourcePort(destination.Port)
   515  	udpHdr.SetLength(uint16(buffer.Len() + clashtcpip.UDPHeaderSize))
   516  	if !w.txChecksumOffload {
   517  		udpHdr.ResetChecksum(ipHdr.PseudoSum())
   518  		ipHdr.ResetChecksum()
   519  	} else {
   520  		udpHdr.OffloadChecksum()
   521  		ipHdr.ResetChecksum()
   522  	}
   523  	if PacketOffset > 0 {
   524  		newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET
   525  	} else {
   526  		newPacket.Advance(-w.frontHeadroom)
   527  	}
   528  	return common.Error(w.tun.Write(newPacket.Bytes()))
   529  }
   530  
// systemUDPPacketWriter6 writes UDP replies for one IPv6 UDP NAT session back
// to the TUN. Each reply is built from a stored copy of the headers of the
// packet that created the session (presumably its first datagram — confirm
// against udpnat), with addresses and ports swapped.
//
// Field order matters: instances are built with positional composite literals.
type systemUDPPacketWriter6 struct {
	tun Tun
	// Bytes reserved in front of the IP header: the device front headroom
	// plus PacketOffset.
	frontHeadroom int
	// Copy of the request's leading bytes up to and including its UDP header
	// (length computed as len(packet) - UDP length + clashtcpip.UDPHeaderSize).
	header []byte
	// Client endpoint of the session; WritePacket does not read it.
	source            netip.AddrPort
	txChecksumOffload bool
}
   538  
   539  func (w *systemUDPPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
   540  	newPacket := buf.NewSize(w.frontHeadroom + len(w.header) + buffer.Len())
   541  	defer newPacket.Release()
   542  	newPacket.Resize(w.frontHeadroom, 0)
   543  	newPacket.Write(w.header)
   544  	newPacket.Write(buffer.Bytes())
   545  	ipHdr := clashtcpip.IPv6Packet(newPacket.Bytes())
   546  	udpLen := uint16(clashtcpip.UDPHeaderSize + buffer.Len())
   547  	ipHdr.SetPayloadLength(udpLen)
   548  	ipHdr.SetDestinationIP(ipHdr.SourceIP())
   549  	ipHdr.SetSourceIP(destination.Addr)
   550  	udpHdr := clashtcpip.UDPPacket(ipHdr.Payload())
   551  	udpHdr.SetDestinationPort(udpHdr.SourcePort())
   552  	udpHdr.SetSourcePort(destination.Port)
   553  	udpHdr.SetLength(udpLen)
   554  	if !w.txChecksumOffload {
   555  		udpHdr.ResetChecksum(ipHdr.PseudoSum())
   556  	} else {
   557  		udpHdr.OffloadChecksum()
   558  	}
   559  	if PacketOffset > 0 {
   560  		newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET6
   561  	} else {
   562  		newPacket.Advance(-w.frontHeadroom)
   563  	}
   564  	return common.Error(w.tun.Write(newPacket.Bytes()))
   565  }