github.com/sagernet/sing-tun@v0.3.0-beta.5/stack_gvisor.go (about)

     1  //go:build with_gvisor
     2  
     3  package tun
     4  
     5  import (
     6  	"context"
     7  	"net/netip"
     8  	"time"
     9  
    10  	"github.com/sagernet/gvisor/pkg/tcpip"
    11  	"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
    12  	"github.com/sagernet/gvisor/pkg/tcpip/header"
    13  	"github.com/sagernet/gvisor/pkg/tcpip/network/ipv4"
    14  	"github.com/sagernet/gvisor/pkg/tcpip/network/ipv6"
    15  	"github.com/sagernet/gvisor/pkg/tcpip/stack"
    16  	"github.com/sagernet/gvisor/pkg/tcpip/transport/icmp"
    17  	"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
    18  	"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
    19  	"github.com/sagernet/gvisor/pkg/waiter"
    20  	"github.com/sagernet/sing/common/bufio"
    21  	"github.com/sagernet/sing/common/canceler"
    22  	E "github.com/sagernet/sing/common/exceptions"
    23  	"github.com/sagernet/sing/common/logger"
    24  	M "github.com/sagernet/sing/common/metadata"
    25  )
    26  
    27  const WithGVisor = true
    28  
    29  const defaultNIC tcpip.NICID = 1
    30  
// GVisor is a Stack implementation backed by the gVisor userspace network
// stack. Values are created by NewGVisor; the stack and endpoint fields are
// only populated once Start succeeds.
type GVisor struct {
	ctx                    context.Context    // passed to handler calls; its cancellation wakes pending TCP handshakes in Start
	tun                    GVisorTun          // device whose NewEndpoint supplies the link endpoint
	endpointIndependentNat bool               // when true, UDP is handled by NewUDPForwarder instead of per-flow gonet conns
	udpTimeout             int64              // UDP idle timeout, in seconds (multiplied by time.Second in Start)
	broadcastAddr          netip.Addr         // derived from TunOptions.Inet4Address; handed to LinkEndpointFilter
	handler                Handler            // receives accepted TCP connections and UDP packet connections
	logger                 logger.Logger      // taken from StackOptions.Logger
	stack                  *stack.Stack       // set by Start; nil until Start succeeds
	endpoint               stack.LinkEndpoint // filtered link endpoint attached to stack; set by Start
}
    42  
// GVisorTun is a Tun device that can expose a gVisor link endpoint.
// NewGVisor rejects tun devices that do not implement it.
type GVisorTun interface {
	Tun
	// NewEndpoint returns the link endpoint that GVisor.Start wraps in a
	// LinkEndpointFilter and attaches to the gVisor stack.
	NewEndpoint() (stack.LinkEndpoint, error)
}
    47  
    48  func NewGVisor(
    49  	options StackOptions,
    50  ) (Stack, error) {
    51  	gTun, isGTun := options.Tun.(GVisorTun)
    52  	if !isGTun {
    53  		return nil, E.New("gVisor stack is unsupported on current platform")
    54  	}
    55  
    56  	gStack := &GVisor{
    57  		ctx:                    options.Context,
    58  		tun:                    gTun,
    59  		endpointIndependentNat: options.EndpointIndependentNat,
    60  		udpTimeout:             options.UDPTimeout,
    61  		broadcastAddr:          BroadcastAddr(options.TunOptions.Inet4Address),
    62  		handler:                options.Handler,
    63  		logger:                 options.Logger,
    64  	}
    65  	return gStack, nil
    66  }
    67  
// Start obtains a link endpoint from the tun device and wraps it in a
// LinkEndpointFilter. It then builds the gVisor stack on top of it and
// installs TCP and UDP handlers that pass accepted flows to t.handler.
// The stack and endpoint are stored on t only after everything succeeded,
// so that Close can tear them down.
func (t *GVisor) Start() error {
	linkEndpoint, err := t.tun.NewEndpoint()
	if err != nil {
		return err
	}
	linkEndpoint = &LinkEndpointFilter{linkEndpoint, t.broadcastAddr, t.tun}
	ipStack, err := newGVisorStack(linkEndpoint)
	if err != nil {
		return err
	}
	// NOTE(review): per gVisor's tcp.NewForwarder(stack, rcvWnd, maxInFlight, handler),
	// 0 selects the default receive window and 1024 bounds in-flight handshakes — confirm.
	tcpForwarder := tcp.NewForwarder(ipStack, 0, 1024, func(r *tcp.ForwarderRequest) {
		var wq waiter.Queue
		// While CreateEndpoint runs, this goroutine wakes wq's waiters if t.ctx
		// is cancelled (presumably so a blocked handshake can observe shutdown).
		// It exits once cancel() below runs.
		handshakeCtx, cancel := context.WithCancel(context.Background())
		go func() {
			select {
			case <-t.ctx.Done():
				wq.Notify(wq.Events())
			case <-handshakeCtx.Done():
			}
		}()
		endpoint, err := r.CreateEndpoint(&wq)
		cancel()
		if err != nil {
			r.Complete(true)
			return
		}
		r.Complete(false)
		// Enable TCP keepalive with a 15s idle time and a 15s probe interval.
		endpoint.SocketOptions().SetKeepAlive(true)
		keepAliveIdle := tcpip.KeepaliveIdleOption(15 * time.Second)
		endpoint.SetSockOpt(&keepAliveIdle)
		keepAliveInterval := tcpip.KeepaliveIntervalOption(15 * time.Second)
		endpoint.SetSockOpt(&keepAliveInterval)
		tcpConn := gonet.NewTCPConn(&wq, endpoint)
		// From the stack's side, the remote peer is the connection's source and
		// the local address is its original destination, hence the swap.
		lAddr := tcpConn.RemoteAddr()
		rAddr := tcpConn.LocalAddr()
		if lAddr == nil || rAddr == nil {
			tcpConn.Close()
			return
		}
		go func() {
			var metadata M.Metadata
			metadata.Source = M.SocksaddrFromNet(lAddr)
			metadata.Destination = M.SocksaddrFromNet(rAddr)
			hErr := t.handler.NewConnection(t.ctx, &gTCPConn{tcpConn}, metadata)
			if hErr != nil {
				endpoint.Abort()
			}
		}()
	})
	ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
	if !t.endpointIndependentNat {
		// One gonet UDP conn per forwarded flow, wrapped with an idle timeout.
		udpForwarder := udp.NewForwarder(ipStack, func(request *udp.ForwarderRequest) {
			var wq waiter.Queue
			endpoint, err := request.CreateEndpoint(&wq)
			if err != nil {
				return
			}
			udpConn := gonet.NewUDPConn(&wq, endpoint)
			// Same source/destination swap as in the TCP handler above.
			lAddr := udpConn.RemoteAddr()
			rAddr := udpConn.LocalAddr()
			if lAddr == nil || rAddr == nil {
				endpoint.Abort()
				return
			}
			gConn := &gUDPConn{UDPConn: udpConn}
			go func() {
				var metadata M.Metadata
				metadata.Source = M.SocksaddrFromNet(lAddr)
				metadata.Destination = M.SocksaddrFromNet(rAddr)
				// udpTimeout is in seconds.
				ctx, conn := canceler.NewPacketConn(t.ctx, bufio.NewUnbindPacketConnWithAddr(gConn, metadata.Destination), time.Duration(t.udpTimeout)*time.Second)
				hErr := t.handler.NewPacketConnection(ctx, conn, metadata)
				if hErr != nil {
					endpoint.Abort()
				}
			}()
		})
		ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
	} else {
		// Endpoint-independent NAT: UDP is handled by the package's own forwarder.
		ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout).HandlePacket)
	}

	t.stack = ipStack
	t.endpoint = linkEndpoint
	return nil
}
   153  
   154  func (t *GVisor) Close() error {
   155  	t.endpoint.Attach(nil)
   156  	t.stack.Close()
   157  	for _, endpoint := range t.stack.CleanupEndpoints() {
   158  		endpoint.Abort()
   159  	}
   160  	return nil
   161  }
   162  
   163  func AddressFromAddr(destination netip.Addr) tcpip.Address {
   164  	if destination.Is6() {
   165  		return tcpip.AddrFrom16(destination.As16())
   166  	} else {
   167  		return tcpip.AddrFrom4(destination.As4())
   168  	}
   169  }
   170  
   171  func AddrFromAddress(address tcpip.Address) netip.Addr {
   172  	if address.Len() == 16 {
   173  		return netip.AddrFrom16(address.As16())
   174  	} else {
   175  		return netip.AddrFrom4(address.As4())
   176  	}
   177  }
   178  
   179  func newGVisorStack(ep stack.LinkEndpoint) (*stack.Stack, error) {
   180  	ipStack := stack.New(stack.Options{
   181  		NetworkProtocols: []stack.NetworkProtocolFactory{
   182  			ipv4.NewProtocol,
   183  			ipv6.NewProtocol,
   184  		},
   185  		TransportProtocols: []stack.TransportProtocolFactory{
   186  			tcp.NewProtocol,
   187  			udp.NewProtocol,
   188  			icmp.NewProtocol4,
   189  			icmp.NewProtocol6,
   190  		},
   191  	})
   192  	tErr := ipStack.CreateNIC(defaultNIC, ep)
   193  	if tErr != nil {
   194  		return nil, E.New("create nic: ", wrapStackError(tErr))
   195  	}
   196  	ipStack.SetRouteTable([]tcpip.Route{
   197  		{Destination: header.IPv4EmptySubnet, NIC: defaultNIC},
   198  		{Destination: header.IPv6EmptySubnet, NIC: defaultNIC},
   199  	})
   200  	ipStack.SetSpoofing(defaultNIC, true)
   201  	ipStack.SetPromiscuousMode(defaultNIC, true)
   202  	bufSize := 20 * 1024
   203  	ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpip.TCPReceiveBufferSizeRangeOption{
   204  		Min:     1,
   205  		Default: bufSize,
   206  		Max:     bufSize,
   207  	})
   208  	ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpip.TCPSendBufferSizeRangeOption{
   209  		Min:     1,
   210  		Default: bufSize,
   211  		Max:     bufSize,
   212  	})
   213  	sOpt := tcpip.TCPSACKEnabled(true)
   214  	ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt)
   215  	mOpt := tcpip.TCPModerateReceiveBufferOption(true)
   216  	ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &mOpt)
   217  	return ipStack, nil
   218  }