github.com/apernet/sing-tun@v0.2.6-0.20240323130332-b9f6511036ad/stack_system.go (about)

     1  package tun
     2  
     3  import (
     4  	"context"
     5  	"net"
     6  	"net/netip"
     7  	"syscall"
     8  	"time"
     9  
    10  	"github.com/apernet/sing-tun/internal/clashtcpip"
    11  	"github.com/sagernet/sing/common"
    12  	"github.com/sagernet/sing/common/buf"
    13  	"github.com/sagernet/sing/common/control"
    14  	E "github.com/sagernet/sing/common/exceptions"
    15  	"github.com/sagernet/sing/common/logger"
    16  	M "github.com/sagernet/sing/common/metadata"
    17  	N "github.com/sagernet/sing/common/network"
    18  	"github.com/sagernet/sing/common/udpnat"
    19  )
    20  
    21  type System struct {
    22  	ctx                context.Context
    23  	tun                Tun
    24  	tunName            string
    25  	mtu                int
    26  	handler            Handler
    27  	logger             logger.Logger
    28  	inet4Prefixes      []netip.Prefix
    29  	inet6Prefixes      []netip.Prefix
    30  	inet4ServerAddress netip.Addr
    31  	inet4Address       netip.Addr
    32  	inet6ServerAddress netip.Addr
    33  	inet6Address       netip.Addr
    34  	broadcastAddr      netip.Addr
    35  	udpTimeout         int64
    36  	tcpListener        net.Listener
    37  	tcpListener6       net.Listener
    38  	tcpPort            uint16
    39  	tcpPort6           uint16
    40  	tcpNat             *TCPNat
    41  	udpNat             *udpnat.Service[netip.AddrPort]
    42  	bindInterface      bool
    43  	interfaceFinder    control.InterfaceFinder
    44  	frontHeadroom      int
    45  	txChecksumOffload  bool
    46  }
    47  
    48  type Session struct {
    49  	SourceAddress      netip.Addr
    50  	DestinationAddress netip.Addr
    51  	SourcePort         uint16
    52  	DestinationPort    uint16
    53  }
    54  
    55  func NewSystem(options StackOptions) (Stack, error) {
    56  	stack := &System{
    57  		ctx:             options.Context,
    58  		tun:             options.Tun,
    59  		tunName:         options.TunOptions.Name,
    60  		mtu:             int(options.TunOptions.MTU),
    61  		udpTimeout:      options.UDPTimeout,
    62  		handler:         options.Handler,
    63  		logger:          options.Logger,
    64  		inet4Prefixes:   options.TunOptions.Inet4Address,
    65  		inet6Prefixes:   options.TunOptions.Inet6Address,
    66  		broadcastAddr:   BroadcastAddr(options.TunOptions.Inet4Address),
    67  		bindInterface:   options.ForwarderBindInterface,
    68  		interfaceFinder: options.InterfaceFinder,
    69  	}
    70  	if len(options.TunOptions.Inet4Address) > 0 {
    71  		if options.TunOptions.Inet4Address[0].Bits() == 32 {
    72  			return nil, E.New("need one more IPv4 address in first prefix for system stack")
    73  		}
    74  		stack.inet4ServerAddress = options.TunOptions.Inet4Address[0].Addr()
    75  		stack.inet4Address = stack.inet4ServerAddress.Next()
    76  	}
    77  	if len(options.TunOptions.Inet6Address) > 0 {
    78  		if options.TunOptions.Inet6Address[0].Bits() == 128 {
    79  			return nil, E.New("need one more IPv6 address in first prefix for system stack")
    80  		}
    81  		stack.inet6ServerAddress = options.TunOptions.Inet6Address[0].Addr()
    82  		stack.inet6Address = stack.inet6ServerAddress.Next()
    83  	}
    84  	if !stack.inet4Address.IsValid() && !stack.inet6Address.IsValid() {
    85  		return nil, E.New("missing interface address")
    86  	}
    87  	return stack, nil
    88  }
    89  
    90  func (s *System) Close() error {
    91  	return common.Close(
    92  		s.tcpListener,
    93  		s.tcpListener6,
    94  	)
    95  }
    96  
    97  func (s *System) Start() error {
    98  	err := s.start()
    99  	if err != nil {
   100  		return err
   101  	}
   102  	go s.tunLoop()
   103  	return nil
   104  }
   105  
   106  func (s *System) start() error {
   107  	err := fixWindowsFirewall()
   108  	if err != nil {
   109  		return E.Cause(err, "fix windows firewall for system stack")
   110  	}
   111  	var listener net.ListenConfig
   112  	if s.bindInterface {
   113  		listener.Control = control.Append(listener.Control, func(network, address string, conn syscall.RawConn) error {
   114  			bindErr := control.BindToInterface0(s.interfaceFinder, conn, network, address, s.tunName, -1, true)
   115  			if bindErr != nil {
   116  				s.logger.Warn("bind forwarder to interface: ", bindErr)
   117  			}
   118  			return nil
   119  		})
   120  	}
   121  	if s.inet4Address.IsValid() {
   122  		tcpListener, err := listener.Listen(s.ctx, "tcp4", net.JoinHostPort(s.inet4ServerAddress.String(), "0"))
   123  		if err != nil {
   124  			return err
   125  		}
   126  		s.tcpListener = tcpListener
   127  		s.tcpPort = M.SocksaddrFromNet(tcpListener.Addr()).Port
   128  		go s.acceptLoop(tcpListener)
   129  	}
   130  	if s.inet6Address.IsValid() {
   131  		tcpListener, err := listener.Listen(s.ctx, "tcp6", net.JoinHostPort(s.inet6ServerAddress.String(), "0"))
   132  		if err != nil {
   133  			return err
   134  		}
   135  		s.tcpListener6 = tcpListener
   136  		s.tcpPort6 = M.SocksaddrFromNet(tcpListener.Addr()).Port
   137  		go s.acceptLoop(tcpListener)
   138  	}
   139  	s.tcpNat = NewNat(s.ctx, time.Second*time.Duration(s.udpTimeout))
   140  	s.udpNat = udpnat.New[netip.AddrPort](s.udpTimeout, s.handler)
   141  	return nil
   142  }
   143  
   144  func (s *System) tunLoop() error {
   145  	if winTun, isWinTun := s.tun.(WinTun); isWinTun {
   146  		return s.wintunLoop(winTun)
   147  	}
   148  	if linuxTUN, isLinuxTUN := s.tun.(LinuxTUN); isLinuxTUN {
   149  		s.frontHeadroom = linuxTUN.FrontHeadroom()
   150  		s.txChecksumOffload = linuxTUN.TXChecksumOffload()
   151  		batchSize := linuxTUN.BatchSize()
   152  		if batchSize > 1 {
   153  			return s.batchLoop(linuxTUN, batchSize)
   154  		}
   155  	}
   156  	packetBuffer := make([]byte, s.mtu+PacketOffset)
   157  	for {
   158  		n, err := s.tun.Read(packetBuffer)
   159  		if err != nil {
   160  			if E.IsClosed(err) {
   161  				return err
   162  			}
   163  			s.logger.Error(E.Cause(err, "read packet"))
   164  		}
   165  		if n < clashtcpip.IPv4PacketMinLength {
   166  			continue
   167  		}
   168  		rawPacket := packetBuffer[:n]
   169  		packet := packetBuffer[PacketOffset:n]
   170  		if s.processPacket(packet) {
   171  			_, err = s.tun.Write(rawPacket)
   172  			if err != nil {
   173  				s.logger.Trace(E.Cause(err, "write packet"))
   174  			}
   175  		}
   176  	}
   177  }
   178  
   179  func (s *System) wintunLoop(winTun WinTun) error {
   180  	for {
   181  		packet, release, err := winTun.ReadPacket()
   182  		if err != nil {
   183  			return err
   184  		}
   185  		if len(packet) < clashtcpip.IPv4PacketMinLength {
   186  			release()
   187  			continue
   188  		}
   189  		if s.processPacket(packet) {
   190  			_, err = winTun.Write(packet)
   191  			if err != nil {
   192  				s.logger.Trace(E.Cause(err, "write packet"))
   193  			}
   194  		}
   195  		release()
   196  	}
   197  }
   198  
   199  func (s *System) batchLoop(linuxTUN LinuxTUN, batchSize int) error {
   200  	packetBuffers := make([][]byte, batchSize)
   201  	writeBuffers := make([][]byte, batchSize)
   202  	packetSizes := make([]int, batchSize)
   203  	for i := range packetBuffers {
   204  		packetBuffers[i] = make([]byte, s.mtu+s.frontHeadroom)
   205  	}
   206  	for {
   207  		n, err := linuxTUN.BatchRead(packetBuffers, s.frontHeadroom, packetSizes)
   208  		if err != nil {
   209  			if E.IsClosed(err) || E.IsMulti(err, errBadFd) || isErrNotPollable(err) {
   210  				return err
   211  			}
   212  			s.logger.Error(E.Cause(err, "batch read packet"))
   213  		}
   214  		if n == 0 {
   215  			continue
   216  		}
   217  		for i := 0; i < n; i++ {
   218  			packetSize := packetSizes[i]
   219  			if packetSize < clashtcpip.IPv4PacketMinLength {
   220  				continue
   221  			}
   222  			packetBuffer := packetBuffers[i]
   223  			packet := packetBuffer[s.frontHeadroom : s.frontHeadroom+packetSize]
   224  			if s.processPacket(packet) {
   225  				writeBuffers = append(writeBuffers, packetBuffer[:s.frontHeadroom+packetSize])
   226  			}
   227  		}
   228  		if len(writeBuffers) > 0 {
   229  			err = linuxTUN.BatchWrite(writeBuffers, s.frontHeadroom)
   230  			if err != nil {
   231  				s.logger.Trace(E.Cause(err, "batch write packet"))
   232  			}
   233  			writeBuffers = writeBuffers[:0]
   234  		}
   235  	}
   236  }
   237  
   238  func (s *System) processPacket(packet []byte) bool {
   239  	var (
   240  		writeBack bool
   241  		err       error
   242  	)
   243  	switch ipVersion := packet[0] >> 4; ipVersion {
   244  	case 4:
   245  		writeBack, err = s.processIPv4(packet)
   246  	case 6:
   247  		writeBack, err = s.processIPv6(packet)
   248  	default:
   249  		err = E.New("ip: unknown version: ", ipVersion)
   250  	}
   251  	if err != nil {
   252  		s.logger.Trace(err)
   253  		return false
   254  	}
   255  	return writeBack
   256  }
   257  
   258  func (s *System) acceptLoop(listener net.Listener) {
   259  	for {
   260  		conn, err := listener.Accept()
   261  		if err != nil {
   262  			return
   263  		}
   264  		connPort := M.SocksaddrFromNet(conn.RemoteAddr()).Port
   265  		session := s.tcpNat.LookupBack(connPort)
   266  		if session == nil {
   267  			s.logger.Trace(E.New("unknown session with port ", connPort))
   268  			continue
   269  		}
   270  		destination := M.SocksaddrFromNetIP(session.Destination)
   271  		if destination.Addr.Is4() {
   272  			for _, prefix := range s.inet4Prefixes {
   273  				if prefix.Contains(destination.Addr) {
   274  					destination.Addr = netip.AddrFrom4([4]byte{127, 0, 0, 1})
   275  					break
   276  				}
   277  			}
   278  		} else {
   279  			for _, prefix := range s.inet6Prefixes {
   280  				if prefix.Contains(destination.Addr) {
   281  					destination.Addr = netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1})
   282  					break
   283  				}
   284  			}
   285  		}
   286  		go func() {
   287  			_ = s.handler.NewConnection(s.ctx, conn, M.Metadata{
   288  				Source:      M.SocksaddrFromNetIP(session.Source),
   289  				Destination: destination,
   290  			})
   291  			if tcpConn, isTCPConn := conn.(*net.TCPConn); isTCPConn {
   292  				_ = tcpConn.SetLinger(0)
   293  			}
   294  			_ = conn.Close()
   295  		}()
   296  	}
   297  }
   298  
   299  func (s *System) processIPv4(packet clashtcpip.IPv4Packet) (writeBack bool, err error) {
   300  	writeBack = true
   301  	destination := packet.DestinationIP()
   302  	if destination == s.broadcastAddr || !destination.IsGlobalUnicast() {
   303  		return
   304  	}
   305  	switch packet.Protocol() {
   306  	case clashtcpip.TCP:
   307  		err = s.processIPv4TCP(packet, packet.Payload())
   308  	case clashtcpip.UDP:
   309  		writeBack = false
   310  		err = s.processIPv4UDP(packet, packet.Payload())
   311  	case clashtcpip.ICMP:
   312  		err = s.processIPv4ICMP(packet, packet.Payload())
   313  	}
   314  	return
   315  }
   316  
   317  func (s *System) processIPv6(packet clashtcpip.IPv6Packet) (writeBack bool, err error) {
   318  	writeBack = true
   319  	if !packet.DestinationIP().IsGlobalUnicast() {
   320  		return
   321  	}
   322  	switch packet.Protocol() {
   323  	case clashtcpip.TCP:
   324  		err = s.processIPv6TCP(packet, packet.Payload())
   325  	case clashtcpip.UDP:
   326  		writeBack = false
   327  		err = s.processIPv6UDP(packet, packet.Payload())
   328  	case clashtcpip.ICMPv6:
   329  		err = s.processIPv6ICMP(packet, packet.Payload())
   330  	}
   331  	return
   332  }
   333  
   334  func (s *System) processIPv4TCP(packet clashtcpip.IPv4Packet, header clashtcpip.TCPPacket) error {
   335  	source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort())
   336  	destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort())
   337  	if !destination.Addr().IsGlobalUnicast() {
   338  		return nil
   339  	} else if source.Addr() == s.inet4ServerAddress && source.Port() == s.tcpPort {
   340  		session := s.tcpNat.LookupBack(destination.Port())
   341  		if session == nil {
   342  			return E.New("ipv4: tcp: session not found: ", destination.Port())
   343  		}
   344  		packet.SetSourceIP(session.Destination.Addr())
   345  		header.SetSourcePort(session.Destination.Port())
   346  		packet.SetDestinationIP(session.Source.Addr())
   347  		header.SetDestinationPort(session.Source.Port())
   348  	} else {
   349  		natPort := s.tcpNat.Lookup(source, destination)
   350  		packet.SetSourceIP(s.inet4Address)
   351  		header.SetSourcePort(natPort)
   352  		packet.SetDestinationIP(s.inet4ServerAddress)
   353  		header.SetDestinationPort(s.tcpPort)
   354  	}
   355  	if !s.txChecksumOffload {
   356  		header.ResetChecksum(packet.PseudoSum())
   357  		packet.ResetChecksum()
   358  	} else {
   359  		header.OffloadChecksum()
   360  		packet.ResetChecksum()
   361  	}
   362  	return nil
   363  }
   364  
   365  func (s *System) processIPv6TCP(packet clashtcpip.IPv6Packet, header clashtcpip.TCPPacket) error {
   366  	source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort())
   367  	destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort())
   368  	if !destination.Addr().IsGlobalUnicast() {
   369  		return nil
   370  	} else if source.Addr() == s.inet6ServerAddress && source.Port() == s.tcpPort6 {
   371  		session := s.tcpNat.LookupBack(destination.Port())
   372  		if session == nil {
   373  			return E.New("ipv6: tcp: session not found: ", destination.Port())
   374  		}
   375  		packet.SetSourceIP(session.Destination.Addr())
   376  		header.SetSourcePort(session.Destination.Port())
   377  		packet.SetDestinationIP(session.Source.Addr())
   378  		header.SetDestinationPort(session.Source.Port())
   379  	} else {
   380  		natPort := s.tcpNat.Lookup(source, destination)
   381  		packet.SetSourceIP(s.inet6Address)
   382  		header.SetSourcePort(natPort)
   383  		packet.SetDestinationIP(s.inet6ServerAddress)
   384  		header.SetDestinationPort(s.tcpPort6)
   385  	}
   386  	if !s.txChecksumOffload {
   387  		header.ResetChecksum(packet.PseudoSum())
   388  	} else {
   389  		header.OffloadChecksum()
   390  	}
   391  	return nil
   392  }
   393  
   394  func (s *System) processIPv4UDP(packet clashtcpip.IPv4Packet, header clashtcpip.UDPPacket) error {
   395  	if packet.Flags()&clashtcpip.FlagMoreFragment != 0 {
   396  		return E.New("ipv4: fragment dropped")
   397  	}
   398  	if packet.FragmentOffset() != 0 {
   399  		return E.New("ipv4: udp: fragment dropped")
   400  	}
   401  	if !header.Valid() {
   402  		return E.New("ipv4: udp: invalid packet")
   403  	}
   404  	source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort())
   405  	destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort())
   406  	if !destination.Addr().IsGlobalUnicast() {
   407  		return nil
   408  	}
   409  	data := buf.As(header.Payload())
   410  	if data.Len() == 0 {
   411  		return nil
   412  	}
   413  	metadata := M.Metadata{
   414  		Source:      M.SocksaddrFromNetIP(source),
   415  		Destination: M.SocksaddrFromNetIP(destination),
   416  	}
   417  	s.udpNat.NewPacket(s.ctx, source, data.ToOwned(), metadata, func(natConn N.PacketConn) N.PacketWriter {
   418  		headerLen := packet.HeaderLen() + clashtcpip.UDPHeaderSize
   419  		headerCopy := make([]byte, headerLen)
   420  		copy(headerCopy, packet[:headerLen])
   421  		return &systemUDPPacketWriter4{
   422  			s.tun,
   423  			s.frontHeadroom + PacketOffset,
   424  			headerCopy,
   425  			source,
   426  			s.txChecksumOffload,
   427  		}
   428  	})
   429  	return nil
   430  }
   431  
   432  func (s *System) processIPv6UDP(packet clashtcpip.IPv6Packet, header clashtcpip.UDPPacket) error {
   433  	if !header.Valid() {
   434  		return E.New("ipv6: udp: invalid packet")
   435  	}
   436  	source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort())
   437  	destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort())
   438  	if !destination.Addr().IsGlobalUnicast() {
   439  		return nil
   440  	}
   441  	data := buf.As(header.Payload())
   442  	if data.Len() == 0 {
   443  		return nil
   444  	}
   445  	metadata := M.Metadata{
   446  		Source:      M.SocksaddrFromNetIP(source),
   447  		Destination: M.SocksaddrFromNetIP(destination),
   448  	}
   449  	s.udpNat.NewPacket(s.ctx, source, data.ToOwned(), metadata, func(natConn N.PacketConn) N.PacketWriter {
   450  		headerLen := len(packet) - int(header.Length()) + clashtcpip.UDPHeaderSize
   451  		headerCopy := make([]byte, headerLen)
   452  		copy(headerCopy, packet[:headerLen])
   453  		return &systemUDPPacketWriter6{
   454  			s.tun,
   455  			s.frontHeadroom + PacketOffset,
   456  			headerCopy,
   457  			source,
   458  			s.txChecksumOffload,
   459  		}
   460  	})
   461  	return nil
   462  }
   463  
   464  func (s *System) processIPv4ICMP(packet clashtcpip.IPv4Packet, header clashtcpip.ICMPPacket) error {
   465  	if header.Type() != clashtcpip.ICMPTypePingRequest || header.Code() != 0 {
   466  		return nil
   467  	}
   468  	header.SetType(clashtcpip.ICMPTypePingResponse)
   469  	sourceAddress := packet.SourceIP()
   470  	packet.SetSourceIP(packet.DestinationIP())
   471  	packet.SetDestinationIP(sourceAddress)
   472  	header.ResetChecksum()
   473  	packet.ResetChecksum()
   474  	return nil
   475  }
   476  
   477  func (s *System) processIPv6ICMP(packet clashtcpip.IPv6Packet, header clashtcpip.ICMPv6Packet) error {
   478  	if header.Type() != clashtcpip.ICMPv6EchoRequest || header.Code() != 0 {
   479  		return nil
   480  	}
   481  	header.SetType(clashtcpip.ICMPv6EchoReply)
   482  	sourceAddress := packet.SourceIP()
   483  	packet.SetSourceIP(packet.DestinationIP())
   484  	packet.SetDestinationIP(sourceAddress)
   485  	header.ResetChecksum(packet.PseudoSum())
   486  	packet.ResetChecksum()
   487  	return nil
   488  }
   489  
   490  type systemUDPPacketWriter4 struct {
   491  	tun               Tun
   492  	frontHeadroom     int
   493  	header            []byte
   494  	source            netip.AddrPort
   495  	txChecksumOffload bool
   496  }
   497  
   498  func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
   499  	newPacket := buf.NewSize(w.frontHeadroom + len(w.header) + buffer.Len())
   500  	defer newPacket.Release()
   501  	newPacket.Resize(w.frontHeadroom, 0)
   502  	newPacket.Write(w.header)
   503  	newPacket.Write(buffer.Bytes())
   504  	ipHdr := clashtcpip.IPv4Packet(newPacket.Bytes())
   505  	ipHdr.SetTotalLength(uint16(newPacket.Len()))
   506  	ipHdr.SetDestinationIP(ipHdr.SourceIP())
   507  	ipHdr.SetSourceIP(destination.Addr)
   508  	udpHdr := clashtcpip.UDPPacket(ipHdr.Payload())
   509  	udpHdr.SetDestinationPort(udpHdr.SourcePort())
   510  	udpHdr.SetSourcePort(destination.Port)
   511  	udpHdr.SetLength(uint16(buffer.Len() + clashtcpip.UDPHeaderSize))
   512  	if !w.txChecksumOffload {
   513  		udpHdr.ResetChecksum(ipHdr.PseudoSum())
   514  		ipHdr.ResetChecksum()
   515  	} else {
   516  		udpHdr.OffloadChecksum()
   517  		ipHdr.ResetChecksum()
   518  	}
   519  	if PacketOffset > 0 {
   520  		newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET
   521  	} else {
   522  		newPacket.Advance(-w.frontHeadroom)
   523  	}
   524  	return common.Error(w.tun.Write(newPacket.Bytes()))
   525  }
   526  
   527  type systemUDPPacketWriter6 struct {
   528  	tun               Tun
   529  	frontHeadroom     int
   530  	header            []byte
   531  	source            netip.AddrPort
   532  	txChecksumOffload bool
   533  }
   534  
   535  func (w *systemUDPPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
   536  	newPacket := buf.NewSize(w.frontHeadroom + len(w.header) + buffer.Len())
   537  	defer newPacket.Release()
   538  	newPacket.Resize(w.frontHeadroom, 0)
   539  	newPacket.Write(w.header)
   540  	newPacket.Write(buffer.Bytes())
   541  	ipHdr := clashtcpip.IPv6Packet(newPacket.Bytes())
   542  	udpLen := uint16(clashtcpip.UDPHeaderSize + buffer.Len())
   543  	ipHdr.SetPayloadLength(udpLen)
   544  	ipHdr.SetDestinationIP(ipHdr.SourceIP())
   545  	ipHdr.SetSourceIP(destination.Addr)
   546  	udpHdr := clashtcpip.UDPPacket(ipHdr.Payload())
   547  	udpHdr.SetDestinationPort(udpHdr.SourcePort())
   548  	udpHdr.SetSourcePort(destination.Port)
   549  	udpHdr.SetLength(udpLen)
   550  	if !w.txChecksumOffload {
   551  		udpHdr.ResetChecksum(ipHdr.PseudoSum())
   552  	} else {
   553  		udpHdr.OffloadChecksum()
   554  	}
   555  	if PacketOffset > 0 {
   556  		newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET6
   557  	} else {
   558  		newPacket.Advance(-w.frontHeadroom)
   559  	}
   560  	return common.Error(w.tun.Write(newPacket.Bytes()))
   561  }