github.com/metacubex/sing-tun@v0.2.7-0.20240512075008-89e7c6208eec/stack_mixed.go (about)

     1  //go:build with_gvisor
     2  
     3  package tun
     4  
     5  import (
     6  	"time"
     7  
     8  	"github.com/sagernet/sing/common/bufio"
     9  	"github.com/sagernet/sing/common/canceler"
    10  	E "github.com/sagernet/sing/common/exceptions"
    11  	M "github.com/sagernet/sing/common/metadata"
    12  
    13  	"github.com/metacubex/gvisor/pkg/buffer"
    14  	"github.com/metacubex/gvisor/pkg/tcpip/adapters/gonet"
    15  	"github.com/metacubex/gvisor/pkg/tcpip/header"
    16  	"github.com/metacubex/gvisor/pkg/tcpip/link/channel"
    17  	"github.com/metacubex/gvisor/pkg/tcpip/stack"
    18  	"github.com/metacubex/gvisor/pkg/tcpip/transport/udp"
    19  	"github.com/metacubex/gvisor/pkg/waiter"
    20  	"github.com/metacubex/sing-tun/internal/clashtcpip"
    21  )
    22  
    23  type Mixed struct {
    24  	*System
    25  	endpointIndependentNat bool
    26  	stack                  *stack.Stack
    27  	endpoint               *channel.Endpoint
    28  }
    29  
    30  func NewMixed(
    31  	options StackOptions,
    32  ) (Stack, error) {
    33  	system, err := NewSystem(options)
    34  	if err != nil {
    35  		return nil, err
    36  	}
    37  	return &Mixed{
    38  		System:                 system.(*System),
    39  		endpointIndependentNat: options.EndpointIndependentNat,
    40  	}, nil
    41  }
    42  
    43  func (m *Mixed) Start() error {
    44  	err := m.System.start()
    45  	if err != nil {
    46  		return err
    47  	}
    48  	endpoint := channel.New(1024, uint32(m.mtu), "")
    49  	ipStack, err := newGVisorStack(endpoint)
    50  	if err != nil {
    51  		return err
    52  	}
    53  	if !m.endpointIndependentNat {
    54  		udpForwarder := udp.NewForwarder(ipStack, func(request *udp.ForwarderRequest) {
    55  			var wq waiter.Queue
    56  			endpoint, err := request.CreateEndpoint(&wq)
    57  			if err != nil {
    58  				return
    59  			}
    60  			udpConn := gonet.NewUDPConn(&wq, endpoint)
    61  			lAddr := udpConn.RemoteAddr()
    62  			rAddr := udpConn.LocalAddr()
    63  			if lAddr == nil || rAddr == nil {
    64  				endpoint.Abort()
    65  				return
    66  			}
    67  			gConn := &gUDPConn{UDPConn: udpConn}
    68  			go func() {
    69  				var metadata M.Metadata
    70  				metadata.Source = M.SocksaddrFromNet(lAddr)
    71  				metadata.Destination = M.SocksaddrFromNet(rAddr)
    72  				ctx, conn := canceler.NewPacketConn(m.ctx, bufio.NewUnbindPacketConnWithAddr(gConn, metadata.Destination), time.Duration(m.udpTimeout)*time.Second)
    73  				hErr := m.handler.NewPacketConnection(ctx, conn, metadata)
    74  				if hErr != nil {
    75  					endpoint.Abort()
    76  				}
    77  			}()
    78  		})
    79  		ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
    80  	} else {
    81  		ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(m.ctx, ipStack, m.handler, m.udpTimeout).HandlePacket)
    82  	}
    83  	m.stack = ipStack
    84  	m.endpoint = endpoint
    85  	go m.tunLoop()
    86  	go m.packetLoop()
    87  	return nil
    88  }
    89  
    90  func (m *Mixed) tunLoop() {
    91  	if winTun, isWinTun := m.tun.(WinTun); isWinTun {
    92  		m.wintunLoop(winTun)
    93  		return
    94  	}
    95  	if linuxTUN, isLinuxTUN := m.tun.(LinuxTUN); isLinuxTUN {
    96  		m.frontHeadroom = linuxTUN.FrontHeadroom()
    97  		m.txChecksumOffload = linuxTUN.TXChecksumOffload()
    98  		batchSize := linuxTUN.BatchSize()
    99  		if batchSize > 1 {
   100  			m.batchLoop(linuxTUN, batchSize)
   101  			return
   102  		}
   103  	}
   104  	packetBuffer := make([]byte, m.mtu+PacketOffset)
   105  	for {
   106  		n, err := m.tun.Read(packetBuffer)
   107  		if err != nil {
   108  			if E.IsClosed(err) {
   109  				return
   110  			}
   111  			m.logger.Error(E.Cause(err, "read packet"))
   112  		}
   113  		if n < clashtcpip.IPv4PacketMinLength {
   114  			continue
   115  		}
   116  		rawPacket := packetBuffer[:n]
   117  		packet := packetBuffer[PacketOffset:n]
   118  		if m.processPacket(packet) {
   119  			_, err = m.tun.Write(rawPacket)
   120  			if err != nil {
   121  				m.logger.Trace(E.Cause(err, "write packet"))
   122  			}
   123  		}
   124  	}
   125  }
   126  
   127  func (m *Mixed) wintunLoop(winTun WinTun) {
   128  	for {
   129  		packet, release, err := winTun.ReadPacket()
   130  		if err != nil {
   131  			return
   132  		}
   133  		if len(packet) < clashtcpip.IPv4PacketMinLength {
   134  			release()
   135  			continue
   136  		}
   137  		if m.processPacket(packet) {
   138  			_, err = winTun.Write(packet)
   139  			if err != nil {
   140  				m.logger.Trace(E.Cause(err, "write packet"))
   141  			}
   142  		}
   143  		release()
   144  	}
   145  }
   146  
   147  func (m *Mixed) batchLoop(linuxTUN LinuxTUN, batchSize int) {
   148  	packetBuffers := make([][]byte, batchSize)
   149  	writeBuffers := make([][]byte, batchSize)
   150  	packetSizes := make([]int, batchSize)
   151  	for i := range packetBuffers {
   152  		packetBuffers[i] = make([]byte, m.mtu+m.frontHeadroom)
   153  	}
   154  	for {
   155  		n, err := linuxTUN.BatchRead(packetBuffers, m.frontHeadroom, packetSizes)
   156  		if err != nil {
   157  			if E.IsClosed(err) {
   158  				return
   159  			}
   160  			m.logger.Error(E.Cause(err, "batch read packet"))
   161  		}
   162  		if n == 0 {
   163  			continue
   164  		}
   165  		for i := 0; i < n; i++ {
   166  			packetSize := packetSizes[i]
   167  			if packetSize < clashtcpip.IPv4PacketMinLength {
   168  				continue
   169  			}
   170  			packetBuffer := packetBuffers[i]
   171  			packet := packetBuffer[m.frontHeadroom : m.frontHeadroom+packetSize]
   172  			if m.processPacket(packet) {
   173  				writeBuffers = append(writeBuffers, packetBuffer[:m.frontHeadroom+packetSize])
   174  			}
   175  		}
   176  		if len(writeBuffers) > 0 {
   177  			err = linuxTUN.BatchWrite(writeBuffers, m.frontHeadroom)
   178  			if err != nil {
   179  				m.logger.Trace(E.Cause(err, "batch write packet"))
   180  			}
   181  			writeBuffers = writeBuffers[:0]
   182  		}
   183  	}
   184  }
   185  
   186  func (m *Mixed) processPacket(packet []byte) bool {
   187  	var (
   188  		writeBack bool
   189  		err       error
   190  	)
   191  	switch ipVersion := packet[0] >> 4; ipVersion {
   192  	case 4:
   193  		writeBack, err = m.processIPv4(packet)
   194  	case 6:
   195  		writeBack, err = m.processIPv6(packet)
   196  	default:
   197  		err = E.New("ip: unknown version: ", ipVersion)
   198  	}
   199  	if err != nil {
   200  		m.logger.Trace(err)
   201  		return false
   202  	}
   203  	return writeBack
   204  }
   205  
   206  func (m *Mixed) processIPv4(packet clashtcpip.IPv4Packet) (writeBack bool, err error) {
   207  	writeBack = true
   208  	destination := packet.DestinationIP()
   209  	if destination == m.broadcastAddr || !destination.IsGlobalUnicast() {
   210  		return
   211  	}
   212  	switch packet.Protocol() {
   213  	case clashtcpip.TCP:
   214  		err = m.processIPv4TCP(packet, packet.Payload())
   215  	case clashtcpip.UDP:
   216  		writeBack = false
   217  		pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
   218  			Payload:           buffer.MakeWithData(packet),
   219  			IsForwardedPacket: true,
   220  		})
   221  		m.endpoint.InjectInbound(header.IPv4ProtocolNumber, pkt)
   222  		pkt.DecRef()
   223  		return
   224  	case clashtcpip.ICMP:
   225  		err = m.processIPv4ICMP(packet, packet.Payload())
   226  	}
   227  	return
   228  }
   229  
   230  func (m *Mixed) processIPv6(packet clashtcpip.IPv6Packet) (writeBack bool, err error) {
   231  	writeBack = true
   232  	if !packet.DestinationIP().IsGlobalUnicast() {
   233  		return
   234  	}
   235  	switch packet.Protocol() {
   236  	case clashtcpip.TCP:
   237  		err = m.processIPv6TCP(packet, packet.Payload())
   238  	case clashtcpip.UDP:
   239  		writeBack = false
   240  		pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
   241  			Payload:           buffer.MakeWithData(packet),
   242  			IsForwardedPacket: true,
   243  		})
   244  		m.endpoint.InjectInbound(header.IPv6ProtocolNumber, pkt)
   245  		pkt.DecRef()
   246  	case clashtcpip.ICMPv6:
   247  		err = m.processIPv6ICMP(packet, packet.Payload())
   248  	}
   249  	return
   250  }
   251  
   252  func (m *Mixed) packetLoop() {
   253  	for {
   254  		packet := m.endpoint.ReadContext(m.ctx)
   255  		if packet == nil {
   256  			break
   257  		}
   258  		bufio.WriteVectorised(m.tun, packet.AsSlices())
   259  		packet.DecRef()
   260  	}
   261  }
   262  
   263  func (m *Mixed) Close() error {
   264  	m.endpoint.Attach(nil)
   265  	m.stack.Close()
   266  	for _, endpoint := range m.stack.CleanupEndpoints() {
   267  		endpoint.Abort()
   268  	}
   269  	return m.System.Close()
   270  }