github.com/metacubex/sing-tun@v0.2.7-0.20240512075008-89e7c6208eec/stack_gvisor.go (about)

     1  //go:build with_gvisor
     2  
     3  package tun
     4  
     5  import (
     6  	"context"
     7  	"net/netip"
     8  	"time"
     9  
    10  	"github.com/sagernet/sing/common/bufio"
    11  	"github.com/sagernet/sing/common/canceler"
    12  	E "github.com/sagernet/sing/common/exceptions"
    13  	"github.com/sagernet/sing/common/logger"
    14  	M "github.com/sagernet/sing/common/metadata"
    15  
    16  	"github.com/metacubex/gvisor/pkg/tcpip"
    17  	"github.com/metacubex/gvisor/pkg/tcpip/adapters/gonet"
    18  	"github.com/metacubex/gvisor/pkg/tcpip/header"
    19  	"github.com/metacubex/gvisor/pkg/tcpip/network/ipv4"
    20  	"github.com/metacubex/gvisor/pkg/tcpip/network/ipv6"
    21  	"github.com/metacubex/gvisor/pkg/tcpip/stack"
    22  	"github.com/metacubex/gvisor/pkg/tcpip/transport/icmp"
    23  	"github.com/metacubex/gvisor/pkg/tcpip/transport/tcp"
    24  	"github.com/metacubex/gvisor/pkg/tcpip/transport/udp"
    25  	"github.com/metacubex/gvisor/pkg/waiter"
    26  )
    27  
    28  const WithGVisor = true
    29  
    30  const defaultNIC tcpip.NICID = 1
    31  
// GVisor is a Stack implementation backed by a gVisor userspace network stack.
//
// The configuration fields are filled in by NewGVisor; stack and endpoint
// stay nil until Start succeeds and are torn down by Close.
type GVisor struct {
	ctx                    context.Context    // passed to the handler and watched for cancellation during TCP handshakes
	tun                    GVisorTun          // supplies the link endpoint in Start
	endpointIndependentNat bool               // when true, UDP is handled by NewUDPForwarder instead of gVisor's udp.Forwarder
	udpTimeout             int64              // UDP timeout in seconds (multiplied by time.Second where used)
	broadcastAddr          netip.Addr         // computed from the TUN IPv4 address; handed to LinkEndpointFilter
	handler                Handler            // receives the forwarded TCP and UDP connections
	logger                 logger.Logger      // copied from StackOptions; not referenced elsewhere in this file
	stack                  *stack.Stack       // set by Start
	endpoint               stack.LinkEndpoint // filtered link endpoint, set by Start
}
    43  
// GVisorTun is a Tun that can additionally produce a gVisor link endpoint.
// NewGVisor rejects any Tun that does not implement it.
type GVisorTun interface {
	Tun
	// NewEndpoint returns the link endpoint that Start attaches to the
	// gVisor stack.
	NewEndpoint() (stack.LinkEndpoint, error)
}
    48  
    49  func NewGVisor(
    50  	options StackOptions,
    51  ) (Stack, error) {
    52  	gTun, isGTun := options.Tun.(GVisorTun)
    53  	if !isGTun {
    54  		return nil, E.New("gVisor stack is unsupported on current platform")
    55  	}
    56  
    57  	gStack := &GVisor{
    58  		ctx:                    options.Context,
    59  		tun:                    gTun,
    60  		endpointIndependentNat: options.EndpointIndependentNat,
    61  		udpTimeout:             options.UDPTimeout,
    62  		broadcastAddr:          BroadcastAddr(options.TunOptions.Inet4Address),
    63  		handler:                options.Handler,
    64  		logger:                 options.Logger,
    65  	}
    66  	return gStack, nil
    67  }
    68  
    69  func (t *GVisor) Start() error {
    70  	linkEndpoint, err := t.tun.NewEndpoint()
    71  	if err != nil {
    72  		return err
    73  	}
    74  	linkEndpoint = &LinkEndpointFilter{linkEndpoint, t.broadcastAddr, t.tun}
    75  	ipStack, err := newGVisorStack(linkEndpoint)
    76  	if err != nil {
    77  		return err
    78  	}
    79  	tcpForwarder := tcp.NewForwarder(ipStack, 0, 1024, func(r *tcp.ForwarderRequest) {
    80  		var wq waiter.Queue
    81  		handshakeCtx, cancel := context.WithCancel(context.Background())
    82  		go func() {
    83  			select {
    84  			case <-t.ctx.Done():
    85  				wq.Notify(wq.Events())
    86  			case <-handshakeCtx.Done():
    87  			}
    88  		}()
    89  		endpoint, err := r.CreateEndpoint(&wq)
    90  		cancel()
    91  		if err != nil {
    92  			r.Complete(true)
    93  			return
    94  		}
    95  		r.Complete(false)
    96  		endpoint.SocketOptions().SetKeepAlive(true)
    97  		keepAliveIdle := tcpip.KeepaliveIdleOption(15 * time.Second)
    98  		endpoint.SetSockOpt(&keepAliveIdle)
    99  		keepAliveInterval := tcpip.KeepaliveIntervalOption(15 * time.Second)
   100  		endpoint.SetSockOpt(&keepAliveInterval)
   101  		tcpConn := gonet.NewTCPConn(&wq, endpoint)
   102  		lAddr := tcpConn.RemoteAddr()
   103  		rAddr := tcpConn.LocalAddr()
   104  		if lAddr == nil || rAddr == nil {
   105  			tcpConn.Close()
   106  			return
   107  		}
   108  		go func() {
   109  			var metadata M.Metadata
   110  			metadata.Source = M.SocksaddrFromNet(lAddr)
   111  			metadata.Destination = M.SocksaddrFromNet(rAddr)
   112  			hErr := t.handler.NewConnection(t.ctx, &gTCPConn{tcpConn}, metadata)
   113  			if hErr != nil {
   114  				endpoint.Abort()
   115  			}
   116  		}()
   117  	})
   118  	ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
   119  	if !t.endpointIndependentNat {
   120  		udpForwarder := udp.NewForwarder(ipStack, func(request *udp.ForwarderRequest) {
   121  			var wq waiter.Queue
   122  			endpoint, err := request.CreateEndpoint(&wq)
   123  			if err != nil {
   124  				return
   125  			}
   126  			udpConn := gonet.NewUDPConn(&wq, endpoint)
   127  			lAddr := udpConn.RemoteAddr()
   128  			rAddr := udpConn.LocalAddr()
   129  			if lAddr == nil || rAddr == nil {
   130  				endpoint.Abort()
   131  				return
   132  			}
   133  			gConn := &gUDPConn{UDPConn: udpConn}
   134  			go func() {
   135  				var metadata M.Metadata
   136  				metadata.Source = M.SocksaddrFromNet(lAddr)
   137  				metadata.Destination = M.SocksaddrFromNet(rAddr)
   138  				ctx, conn := canceler.NewPacketConn(t.ctx, bufio.NewUnbindPacketConnWithAddr(gConn, metadata.Destination), time.Duration(t.udpTimeout)*time.Second)
   139  				hErr := t.handler.NewPacketConnection(ctx, conn, metadata)
   140  				if hErr != nil {
   141  					endpoint.Abort()
   142  				}
   143  			}()
   144  		})
   145  		ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
   146  	} else {
   147  		ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout).HandlePacket)
   148  	}
   149  
   150  	t.stack = ipStack
   151  	t.endpoint = linkEndpoint
   152  	return nil
   153  }
   154  
   155  func (t *GVisor) Close() error {
   156  	t.endpoint.Attach(nil)
   157  	t.stack.Close()
   158  	for _, endpoint := range t.stack.CleanupEndpoints() {
   159  		endpoint.Abort()
   160  	}
   161  	return nil
   162  }
   163  
   164  func AddressFromAddr(destination netip.Addr) tcpip.Address {
   165  	if destination.Is6() {
   166  		return tcpip.AddrFrom16(destination.As16())
   167  	} else {
   168  		return tcpip.AddrFrom4(destination.As4())
   169  	}
   170  }
   171  
   172  func AddrFromAddress(address tcpip.Address) netip.Addr {
   173  	if address.Len() == 16 {
   174  		return netip.AddrFrom16(address.As16())
   175  	} else {
   176  		return netip.AddrFrom4(address.As4())
   177  	}
   178  }
   179  
   180  func newGVisorStack(ep stack.LinkEndpoint) (*stack.Stack, error) {
   181  	ipStack := stack.New(stack.Options{
   182  		NetworkProtocols: []stack.NetworkProtocolFactory{
   183  			ipv4.NewProtocol,
   184  			ipv6.NewProtocol,
   185  		},
   186  		TransportProtocols: []stack.TransportProtocolFactory{
   187  			tcp.NewProtocol,
   188  			udp.NewProtocol,
   189  			icmp.NewProtocol4,
   190  			icmp.NewProtocol6,
   191  		},
   192  	})
   193  	tErr := ipStack.CreateNIC(defaultNIC, ep)
   194  	if tErr != nil {
   195  		return nil, E.New("create nic: ", wrapStackError(tErr))
   196  	}
   197  	ipStack.SetRouteTable([]tcpip.Route{
   198  		{Destination: header.IPv4EmptySubnet, NIC: defaultNIC},
   199  		{Destination: header.IPv6EmptySubnet, NIC: defaultNIC},
   200  	})
   201  	ipStack.SetSpoofing(defaultNIC, true)
   202  	ipStack.SetPromiscuousMode(defaultNIC, true)
   203  	bufSize := 20 * 1024
   204  	ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpip.TCPReceiveBufferSizeRangeOption{
   205  		Min:     1,
   206  		Default: bufSize,
   207  		Max:     bufSize,
   208  	})
   209  	ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpip.TCPSendBufferSizeRangeOption{
   210  		Min:     1,
   211  		Default: bufSize,
   212  		Max:     bufSize,
   213  	})
   214  	sOpt := tcpip.TCPSACKEnabled(true)
   215  	ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt)
   216  	mOpt := tcpip.TCPModerateReceiveBufferOption(true)
   217  	ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &mOpt)
   218  	return ipStack, nil
   219  }