github.com/sagernet/sing-tun@v0.3.0-beta.5/stack_mixed.go (about)

     1  //go:build with_gvisor
     2  
     3  package tun
     4  
     5  import (
     6  	"time"
     7  
     8  	"github.com/sagernet/gvisor/pkg/buffer"
     9  	"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
    10  	"github.com/sagernet/gvisor/pkg/tcpip/header"
    11  	"github.com/sagernet/gvisor/pkg/tcpip/link/channel"
    12  	"github.com/sagernet/gvisor/pkg/tcpip/stack"
    13  	"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
    14  	"github.com/sagernet/gvisor/pkg/waiter"
    15  	"github.com/sagernet/sing-tun/internal/clashtcpip"
    16  	"github.com/sagernet/sing/common/bufio"
    17  	"github.com/sagernet/sing/common/canceler"
    18  	E "github.com/sagernet/sing/common/exceptions"
    19  	M "github.com/sagernet/sing/common/metadata"
    20  )
    21  
// Mixed is a TUN stack (built only with the with_gvisor tag) that embeds the
// System stack and adds a gVisor network stack for UDP. In processIPv4 /
// processIPv6, TCP and ICMP packets go to processIPv4TCP / processIPv4ICMP
// (and the IPv6 equivalents), which are not defined in this file and are
// presumably promoted from *System. UDP packets are injected into the gVisor
// stack through a channel link endpoint.
type Mixed struct {
	*System
	// endpointIndependentNat selects the UDP handler installed in Start:
	// when true, NewUDPForwarder handles UDP; otherwise a gVisor
	// udp.Forwarder hands each flow to the handler as a packet connection.
	endpointIndependentNat bool
	// stack is the gVisor network stack created in Start (nil before that).
	stack                  *stack.Stack
	// endpoint is the channel link endpoint attached to stack. UDP packets
	// read from the TUN are injected into it, and packetLoop reads packets
	// from it and writes them to the TUN. It is nil before Start.
	endpoint               *channel.Endpoint
}
    28  
    29  func NewMixed(
    30  	options StackOptions,
    31  ) (Stack, error) {
    32  	system, err := NewSystem(options)
    33  	if err != nil {
    34  		return nil, err
    35  	}
    36  	return &Mixed{
    37  		System:                 system.(*System),
    38  		endpointIndependentNat: options.EndpointIndependentNat,
    39  	}, nil
    40  }
    41  
    42  func (m *Mixed) Start() error {
    43  	err := m.System.start()
    44  	if err != nil {
    45  		return err
    46  	}
    47  	endpoint := channel.New(1024, uint32(m.mtu), "")
    48  	ipStack, err := newGVisorStack(endpoint)
    49  	if err != nil {
    50  		return err
    51  	}
    52  	if !m.endpointIndependentNat {
    53  		udpForwarder := udp.NewForwarder(ipStack, func(request *udp.ForwarderRequest) {
    54  			var wq waiter.Queue
    55  			endpoint, err := request.CreateEndpoint(&wq)
    56  			if err != nil {
    57  				return
    58  			}
    59  			udpConn := gonet.NewUDPConn(&wq, endpoint)
    60  			lAddr := udpConn.RemoteAddr()
    61  			rAddr := udpConn.LocalAddr()
    62  			if lAddr == nil || rAddr == nil {
    63  				endpoint.Abort()
    64  				return
    65  			}
    66  			gConn := &gUDPConn{UDPConn: udpConn}
    67  			go func() {
    68  				var metadata M.Metadata
    69  				metadata.Source = M.SocksaddrFromNet(lAddr)
    70  				metadata.Destination = M.SocksaddrFromNet(rAddr)
    71  				ctx, conn := canceler.NewPacketConn(m.ctx, bufio.NewUnbindPacketConnWithAddr(gConn, metadata.Destination), time.Duration(m.udpTimeout)*time.Second)
    72  				hErr := m.handler.NewPacketConnection(ctx, conn, metadata)
    73  				if hErr != nil {
    74  					endpoint.Abort()
    75  				}
    76  			}()
    77  		})
    78  		ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
    79  	} else {
    80  		ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(m.ctx, ipStack, m.handler, m.udpTimeout).HandlePacket)
    81  	}
    82  	m.stack = ipStack
    83  	m.endpoint = endpoint
    84  	go m.tunLoop()
    85  	go m.packetLoop()
    86  	return nil
    87  }
    88  
    89  func (m *Mixed) tunLoop() {
    90  	if winTun, isWinTun := m.tun.(WinTun); isWinTun {
    91  		m.wintunLoop(winTun)
    92  		return
    93  	}
    94  	if linuxTUN, isLinuxTUN := m.tun.(LinuxTUN); isLinuxTUN {
    95  		m.frontHeadroom = linuxTUN.FrontHeadroom()
    96  		m.txChecksumOffload = linuxTUN.TXChecksumOffload()
    97  		batchSize := linuxTUN.BatchSize()
    98  		if batchSize > 1 {
    99  			m.batchLoop(linuxTUN, batchSize)
   100  			return
   101  		}
   102  	}
   103  	packetBuffer := make([]byte, m.mtu+PacketOffset)
   104  	for {
   105  		n, err := m.tun.Read(packetBuffer)
   106  		if err != nil {
   107  			if E.IsClosed(err) {
   108  				return
   109  			}
   110  			m.logger.Error(E.Cause(err, "read packet"))
   111  		}
   112  		if n < clashtcpip.IPv4PacketMinLength {
   113  			continue
   114  		}
   115  		rawPacket := packetBuffer[:n]
   116  		packet := packetBuffer[PacketOffset:n]
   117  		if m.processPacket(packet) {
   118  			_, err = m.tun.Write(rawPacket)
   119  			if err != nil {
   120  				m.logger.Trace(E.Cause(err, "write packet"))
   121  			}
   122  		}
   123  	}
   124  }
   125  
   126  func (m *Mixed) wintunLoop(winTun WinTun) {
   127  	for {
   128  		packet, release, err := winTun.ReadPacket()
   129  		if err != nil {
   130  			return
   131  		}
   132  		if len(packet) < clashtcpip.IPv4PacketMinLength {
   133  			release()
   134  			continue
   135  		}
   136  		if m.processPacket(packet) {
   137  			_, err = winTun.Write(packet)
   138  			if err != nil {
   139  				m.logger.Trace(E.Cause(err, "write packet"))
   140  			}
   141  		}
   142  		release()
   143  	}
   144  }
   145  
   146  func (m *Mixed) batchLoop(linuxTUN LinuxTUN, batchSize int) {
   147  	packetBuffers := make([][]byte, batchSize)
   148  	writeBuffers := make([][]byte, batchSize)
   149  	packetSizes := make([]int, batchSize)
   150  	for i := range packetBuffers {
   151  		packetBuffers[i] = make([]byte, m.mtu+m.frontHeadroom)
   152  	}
   153  	for {
   154  		n, err := linuxTUN.BatchRead(packetBuffers, m.frontHeadroom, packetSizes)
   155  		if err != nil {
   156  			if E.IsClosed(err) {
   157  				return
   158  			}
   159  			m.logger.Error(E.Cause(err, "batch read packet"))
   160  		}
   161  		if n == 0 {
   162  			continue
   163  		}
   164  		for i := 0; i < n; i++ {
   165  			packetSize := packetSizes[i]
   166  			if packetSize < clashtcpip.IPv4PacketMinLength {
   167  				continue
   168  			}
   169  			packetBuffer := packetBuffers[i]
   170  			packet := packetBuffer[m.frontHeadroom : m.frontHeadroom+packetSize]
   171  			if m.processPacket(packet) {
   172  				writeBuffers = append(writeBuffers, packetBuffer[:m.frontHeadroom+packetSize])
   173  			}
   174  		}
   175  		if len(writeBuffers) > 0 {
   176  			err = linuxTUN.BatchWrite(writeBuffers, m.frontHeadroom)
   177  			if err != nil {
   178  				m.logger.Trace(E.Cause(err, "batch write packet"))
   179  			}
   180  			writeBuffers = writeBuffers[:0]
   181  		}
   182  	}
   183  }
   184  
   185  func (m *Mixed) processPacket(packet []byte) bool {
   186  	var (
   187  		writeBack bool
   188  		err       error
   189  	)
   190  	switch ipVersion := packet[0] >> 4; ipVersion {
   191  	case 4:
   192  		writeBack, err = m.processIPv4(packet)
   193  	case 6:
   194  		writeBack, err = m.processIPv6(packet)
   195  	default:
   196  		err = E.New("ip: unknown version: ", ipVersion)
   197  	}
   198  	if err != nil {
   199  		m.logger.Trace(err)
   200  		return false
   201  	}
   202  	return writeBack
   203  }
   204  
   205  func (m *Mixed) processIPv4(packet clashtcpip.IPv4Packet) (writeBack bool, err error) {
   206  	writeBack = true
   207  	destination := packet.DestinationIP()
   208  	if destination == m.broadcastAddr || !destination.IsGlobalUnicast() {
   209  		return
   210  	}
   211  	switch packet.Protocol() {
   212  	case clashtcpip.TCP:
   213  		err = m.processIPv4TCP(packet, packet.Payload())
   214  	case clashtcpip.UDP:
   215  		writeBack = false
   216  		pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
   217  			Payload:           buffer.MakeWithData(packet),
   218  			IsForwardedPacket: true,
   219  		})
   220  		m.endpoint.InjectInbound(header.IPv4ProtocolNumber, pkt)
   221  		pkt.DecRef()
   222  		return
   223  	case clashtcpip.ICMP:
   224  		err = m.processIPv4ICMP(packet, packet.Payload())
   225  	}
   226  	return
   227  }
   228  
   229  func (m *Mixed) processIPv6(packet clashtcpip.IPv6Packet) (writeBack bool, err error) {
   230  	writeBack = true
   231  	if !packet.DestinationIP().IsGlobalUnicast() {
   232  		return
   233  	}
   234  	switch packet.Protocol() {
   235  	case clashtcpip.TCP:
   236  		err = m.processIPv6TCP(packet, packet.Payload())
   237  	case clashtcpip.UDP:
   238  		writeBack = false
   239  		pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
   240  			Payload:           buffer.MakeWithData(packet),
   241  			IsForwardedPacket: true,
   242  		})
   243  		m.endpoint.InjectInbound(header.IPv6ProtocolNumber, pkt)
   244  		pkt.DecRef()
   245  	case clashtcpip.ICMPv6:
   246  		err = m.processIPv6ICMP(packet, packet.Payload())
   247  	}
   248  	return
   249  }
   250  
   251  func (m *Mixed) packetLoop() {
   252  	for {
   253  		packet := m.endpoint.ReadContext(m.ctx)
   254  		if packet == nil {
   255  			break
   256  		}
   257  		bufio.WriteVectorised(m.tun, packet.AsSlices())
   258  		packet.DecRef()
   259  	}
   260  }
   261  
   262  func (m *Mixed) Close() error {
   263  	m.endpoint.Attach(nil)
   264  	m.stack.Close()
   265  	for _, endpoint := range m.stack.CleanupEndpoints() {
   266  		endpoint.Abort()
   267  	}
   268  	return m.System.Close()
   269  }