github.com/MerlinKodo/sing-tun@v0.1.15/stack_mixed.go (about)

     1  //go:build with_gvisor
     2  
     3  package tun
     4  
     5  import (
     6  	"time"
     7  
     8  	"github.com/sagernet/sing/common"
     9  	"github.com/sagernet/sing/common/bufio"
    10  	"github.com/sagernet/sing/common/canceler"
    11  	E "github.com/sagernet/sing/common/exceptions"
    12  	M "github.com/sagernet/sing/common/metadata"
    13  	N "github.com/sagernet/sing/common/network"
    14  
    15  	"github.com/MerlinKodo/gvisor/pkg/buffer"
    16  	"github.com/MerlinKodo/gvisor/pkg/tcpip/adapters/gonet"
    17  	"github.com/MerlinKodo/gvisor/pkg/tcpip/header"
    18  	"github.com/MerlinKodo/gvisor/pkg/tcpip/link/channel"
    19  	"github.com/MerlinKodo/gvisor/pkg/tcpip/stack"
    20  	"github.com/MerlinKodo/gvisor/pkg/tcpip/transport/udp"
    21  	"github.com/MerlinKodo/gvisor/pkg/waiter"
    22  	"github.com/MerlinKodo/sing-tun/internal/clashtcpip"
    23  )
    24  
// Mixed is a TUN stack that embeds the System stack and adds a gVisor
// network stack, fed through an in-memory channel endpoint, for UDP traffic.
type Mixed struct {
	*System
	// writer is used by packetLoop to write packets produced by the gVisor
	// stack back to the TUN device (created from options.Tun in NewMixed).
	writer N.VectorisedWriter
	// endpointIndependentNat selects the UDP path in Start: when false, a
	// gVisor udp.Forwarder creates one packet connection per flow; when true,
	// the package's NewUDPForwarder handles UDP instead.
	endpointIndependentNat bool
	// stack is the gVisor stack; nil until Start completes successfully.
	stack *stack.Stack
	// endpoint is the channel link endpoint connecting the TUN loops to
	// stack; nil until Start completes successfully.
	endpoint *channel.Endpoint
}
    32  
    33  func NewMixed(
    34  	options StackOptions,
    35  ) (Stack, error) {
    36  	system, err := NewSystem(options)
    37  	if err != nil {
    38  		return nil, err
    39  	}
    40  	return &Mixed{
    41  		System:                 system.(*System),
    42  		writer:                 options.Tun.CreateVectorisedWriter(),
    43  		endpointIndependentNat: options.EndpointIndependentNat,
    44  	}, nil
    45  }
    46  
    47  func (m *Mixed) Start() error {
    48  	err := m.System.start()
    49  	if err != nil {
    50  		return err
    51  	}
    52  	endpoint := channel.New(1024, m.mtu, "")
    53  	ipStack, err := newGVisorStack(endpoint)
    54  	if err != nil {
    55  		return err
    56  	}
    57  	if !m.endpointIndependentNat {
    58  		udpForwarder := udp.NewForwarder(ipStack, func(request *udp.ForwarderRequest) {
    59  			var wq waiter.Queue
    60  			endpoint, err := request.CreateEndpoint(&wq)
    61  			if err != nil {
    62  				return
    63  			}
    64  			udpConn := gonet.NewUDPConn(ipStack, &wq, endpoint)
    65  			lAddr := udpConn.RemoteAddr()
    66  			rAddr := udpConn.LocalAddr()
    67  			if lAddr == nil || rAddr == nil {
    68  				endpoint.Abort()
    69  				return
    70  			}
    71  			gConn := &gUDPConn{UDPConn: udpConn}
    72  			go func() {
    73  				var metadata M.Metadata
    74  				metadata.Source = M.SocksaddrFromNet(lAddr)
    75  				metadata.Destination = M.SocksaddrFromNet(rAddr)
    76  				ctx, conn := canceler.NewPacketConn(m.ctx, bufio.NewPacketConn(&bufio.UnbindPacketConn{ExtendedConn: bufio.NewExtendedConn(gConn), Addr: M.SocksaddrFromNet(rAddr)}), time.Duration(m.udpTimeout)*time.Second)
    77  				hErr := m.handler.NewPacketConnection(ctx, conn, metadata)
    78  				if hErr != nil {
    79  					endpoint.Abort()
    80  				}
    81  			}()
    82  		})
    83  		ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
    84  	} else {
    85  		ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(m.ctx, ipStack, m.handler, m.udpTimeout).HandlePacket)
    86  	}
    87  	m.stack = ipStack
    88  	m.endpoint = endpoint
    89  	go m.tunLoop()
    90  	go m.packetLoop()
    91  	return nil
    92  }
    93  
    94  func (m *Mixed) tunLoop() {
    95  	if winTun, isWinTun := m.tun.(WinTun); isWinTun {
    96  		m.wintunLoop(winTun)
    97  		return
    98  	}
    99  	packetBuffer := make([]byte, m.mtu+PacketOffset)
   100  	for {
   101  		n, err := m.tun.Read(packetBuffer)
   102  		if err != nil {
   103  			return
   104  		}
   105  		if n < clashtcpip.IPv4PacketMinLength {
   106  			continue
   107  		}
   108  		packet := packetBuffer[PacketOffset:n]
   109  		switch ipVersion := packet[0] >> 4; ipVersion {
   110  		case 4:
   111  			err = m.processIPv4(packet)
   112  		case 6:
   113  			err = m.processIPv6(packet)
   114  		default:
   115  			err = E.New("ip: unknown version: ", ipVersion)
   116  		}
   117  		if err != nil {
   118  			m.logger.Trace(err)
   119  		}
   120  	}
   121  }
   122  
   123  func (m *Mixed) wintunLoop(winTun WinTun) {
   124  	for {
   125  		packet, release, err := winTun.ReadPacket()
   126  		if err != nil {
   127  			return
   128  		}
   129  		if len(packet) < clashtcpip.IPv4PacketMinLength {
   130  			release()
   131  			continue
   132  		}
   133  		switch ipVersion := packet[0] >> 4; ipVersion {
   134  		case 4:
   135  			err = m.processIPv4(packet)
   136  		case 6:
   137  			err = m.processIPv6(packet)
   138  		default:
   139  			err = E.New("ip: unknown version: ", ipVersion)
   140  		}
   141  		if err != nil {
   142  			m.logger.Trace(err)
   143  		}
   144  		release()
   145  	}
   146  }
   147  
   148  func (m *Mixed) processIPv4(packet clashtcpip.IPv4Packet) error {
   149  	switch packet.Protocol() {
   150  	case clashtcpip.TCP:
   151  		return m.processIPv4TCP(packet, packet.Payload())
   152  	case clashtcpip.UDP:
   153  		pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
   154  			Payload: buffer.MakeWithData(packet),
   155  		})
   156  		m.endpoint.InjectInbound(header.IPv4ProtocolNumber, pkt)
   157  		pkt.DecRef()
   158  		return nil
   159  	case clashtcpip.ICMP:
   160  		return m.processIPv4ICMP(packet, packet.Payload())
   161  	default:
   162  		return common.Error(m.tun.Write(packet))
   163  	}
   164  }
   165  
   166  func (m *Mixed) processIPv6(packet clashtcpip.IPv6Packet) error {
   167  	switch packet.Protocol() {
   168  	case clashtcpip.TCP:
   169  		return m.processIPv6TCP(packet, packet.Payload())
   170  	case clashtcpip.UDP:
   171  		pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
   172  			Payload: buffer.MakeWithData(packet),
   173  		})
   174  		m.endpoint.InjectInbound(header.IPv6ProtocolNumber, pkt)
   175  		pkt.DecRef()
   176  		return nil
   177  	case clashtcpip.ICMPv6:
   178  		return m.processIPv6ICMP(packet, packet.Payload())
   179  	default:
   180  		return common.Error(m.tun.Write(packet))
   181  	}
   182  }
   183  
   184  func (m *Mixed) packetLoop() {
   185  	for {
   186  		packet := m.endpoint.ReadContext(m.ctx)
   187  		if packet == nil {
   188  			break
   189  		}
   190  		bufio.WriteVectorised(m.writer, packet.AsSlices())
   191  		packet.DecRef()
   192  	}
   193  }
   194  
   195  func (m *Mixed) Close() error {
   196  	m.endpoint.Attach(nil)
   197  	m.stack.Close()
   198  	for _, endpoint := range m.stack.CleanupEndpoints() {
   199  		endpoint.Abort()
   200  	}
   201  	return m.System.Close()
   202  }