github.com/sagernet/sing-box@v1.2.7/transport/wireguard/device_stack.go (about)

     1  //go:build with_gvisor
     2  
     3  package wireguard
     4  
     5  import (
     6  	"context"
     7  	"net"
     8  	"net/netip"
     9  	"os"
    10  
    11  	E "github.com/sagernet/sing/common/exceptions"
    12  	M "github.com/sagernet/sing/common/metadata"
    13  	N "github.com/sagernet/sing/common/network"
    14  	"github.com/sagernet/wireguard-go/tun"
    15  
    16  	"gvisor.dev/gvisor/pkg/bufferv2"
    17  	"gvisor.dev/gvisor/pkg/tcpip"
    18  	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
    19  	"gvisor.dev/gvisor/pkg/tcpip/header"
    20  	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
    21  	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
    22  	"gvisor.dev/gvisor/pkg/tcpip/stack"
    23  	"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
    24  	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
    25  	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
    26  )
    27  
    28  var _ Device = (*StackDevice)(nil)
    29  
    30  const defaultNIC tcpip.NICID = 1
    31  
    32  type StackDevice struct {
    33  	stack      *stack.Stack
    34  	mtu        uint32
    35  	events     chan tun.Event
    36  	outbound   chan *stack.PacketBuffer
    37  	done       chan struct{}
    38  	dispatcher stack.NetworkDispatcher
    39  	addr4      tcpip.Address
    40  	addr6      tcpip.Address
    41  }
    42  
    43  func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (*StackDevice, error) {
    44  	ipStack := stack.New(stack.Options{
    45  		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
    46  		TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6},
    47  		HandleLocal:        true,
    48  	})
    49  	tunDevice := &StackDevice{
    50  		stack:    ipStack,
    51  		mtu:      mtu,
    52  		events:   make(chan tun.Event, 1),
    53  		outbound: make(chan *stack.PacketBuffer, 256),
    54  		done:     make(chan struct{}),
    55  	}
    56  	err := ipStack.CreateNIC(defaultNIC, (*wireEndpoint)(tunDevice))
    57  	if err != nil {
    58  		return nil, E.New(err.String())
    59  	}
    60  	for _, prefix := range localAddresses {
    61  		addr := tcpip.Address(prefix.Addr().AsSlice())
    62  		protoAddr := tcpip.ProtocolAddress{
    63  			AddressWithPrefix: tcpip.AddressWithPrefix{
    64  				Address:   addr,
    65  				PrefixLen: prefix.Bits(),
    66  			},
    67  		}
    68  		if prefix.Addr().Is4() {
    69  			tunDevice.addr4 = addr
    70  			protoAddr.Protocol = ipv4.ProtocolNumber
    71  		} else {
    72  			tunDevice.addr6 = addr
    73  			protoAddr.Protocol = ipv6.ProtocolNumber
    74  		}
    75  		err = ipStack.AddProtocolAddress(defaultNIC, protoAddr, stack.AddressProperties{})
    76  		if err != nil {
    77  			return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", err.String())
    78  		}
    79  	}
    80  	sOpt := tcpip.TCPSACKEnabled(true)
    81  	ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt)
    82  	cOpt := tcpip.CongestionControlOption("cubic")
    83  	ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &cOpt)
    84  	ipStack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: defaultNIC})
    85  	ipStack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: defaultNIC})
    86  	return tunDevice, nil
    87  }
    88  
    89  func (w *StackDevice) NewEndpoint() (stack.LinkEndpoint, error) {
    90  	return (*wireEndpoint)(w), nil
    91  }
    92  
    93  func (w *StackDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
    94  	addr := tcpip.FullAddress{
    95  		NIC:  defaultNIC,
    96  		Port: destination.Port,
    97  		Addr: tcpip.Address(destination.Addr.AsSlice()),
    98  	}
    99  	bind := tcpip.FullAddress{
   100  		NIC: defaultNIC,
   101  	}
   102  	var networkProtocol tcpip.NetworkProtocolNumber
   103  	if destination.IsIPv4() {
   104  		networkProtocol = header.IPv4ProtocolNumber
   105  		bind.Addr = w.addr4
   106  	} else {
   107  		networkProtocol = header.IPv6ProtocolNumber
   108  		bind.Addr = w.addr6
   109  	}
   110  	switch N.NetworkName(network) {
   111  	case N.NetworkTCP:
   112  		tcpConn, err := gonet.DialTCPWithBind(ctx, w.stack, bind, addr, networkProtocol)
   113  		if err != nil {
   114  			return nil, err
   115  		}
   116  		return tcpConn, nil
   117  	case N.NetworkUDP:
   118  		udpConn, err := gonet.DialUDP(w.stack, &bind, &addr, networkProtocol)
   119  		if err != nil {
   120  			return nil, err
   121  		}
   122  		return udpConn, nil
   123  	default:
   124  		return nil, E.Extend(N.ErrUnknownNetwork, network)
   125  	}
   126  }
   127  
   128  func (w *StackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
   129  	bind := tcpip.FullAddress{
   130  		NIC: defaultNIC,
   131  	}
   132  	var networkProtocol tcpip.NetworkProtocolNumber
   133  	if destination.IsIPv4() || w.addr6 == "" {
   134  		networkProtocol = header.IPv4ProtocolNumber
   135  		bind.Addr = w.addr4
   136  	} else {
   137  		networkProtocol = header.IPv6ProtocolNumber
   138  		bind.Addr = w.addr6
   139  	}
   140  	udpConn, err := gonet.DialUDP(w.stack, &bind, nil, networkProtocol)
   141  	if err != nil {
   142  		return nil, err
   143  	}
   144  	return udpConn, nil
   145  }
   146  
   147  func (w *StackDevice) Start() error {
   148  	w.events <- tun.EventUp
   149  	return nil
   150  }
   151  
   152  func (w *StackDevice) File() *os.File {
   153  	return nil
   154  }
   155  
   156  func (w *StackDevice) Read(p []byte, offset int) (n int, err error) {
   157  	select {
   158  	case packetBuffer, ok := <-w.outbound:
   159  		if !ok {
   160  			return 0, os.ErrClosed
   161  		}
   162  		defer packetBuffer.DecRef()
   163  		p = p[offset:]
   164  		for _, slice := range packetBuffer.AsSlices() {
   165  			n += copy(p[n:], slice)
   166  		}
   167  		return
   168  	case <-w.done:
   169  		return 0, os.ErrClosed
   170  	}
   171  }
   172  
   173  func (w *StackDevice) Write(p []byte, offset int) (n int, err error) {
   174  	p = p[offset:]
   175  	if len(p) == 0 {
   176  		return
   177  	}
   178  	var networkProtocol tcpip.NetworkProtocolNumber
   179  	switch header.IPVersion(p) {
   180  	case header.IPv4Version:
   181  		networkProtocol = header.IPv4ProtocolNumber
   182  	case header.IPv6Version:
   183  		networkProtocol = header.IPv6ProtocolNumber
   184  	}
   185  	packetBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{
   186  		Payload: bufferv2.MakeWithData(p),
   187  	})
   188  	defer packetBuffer.DecRef()
   189  	w.dispatcher.DeliverNetworkPacket(networkProtocol, packetBuffer)
   190  	n = len(p)
   191  	return
   192  }
   193  
   194  func (w *StackDevice) Flush() error {
   195  	return nil
   196  }
   197  
   198  func (w *StackDevice) MTU() (int, error) {
   199  	return int(w.mtu), nil
   200  }
   201  
   202  func (w *StackDevice) Name() (string, error) {
   203  	return "sing-box", nil
   204  }
   205  
   206  func (w *StackDevice) Events() chan tun.Event {
   207  	return w.events
   208  }
   209  
   210  func (w *StackDevice) Close() error {
   211  	select {
   212  	case <-w.done:
   213  		return os.ErrClosed
   214  	default:
   215  	}
   216  	w.stack.Close()
   217  	for _, endpoint := range w.stack.CleanupEndpoints() {
   218  		endpoint.Abort()
   219  	}
   220  	w.stack.Wait()
   221  	close(w.done)
   222  	return nil
   223  }
   224  
   225  var _ stack.LinkEndpoint = (*wireEndpoint)(nil)
   226  
   227  type wireEndpoint StackDevice
   228  
   229  func (ep *wireEndpoint) MTU() uint32 {
   230  	return ep.mtu
   231  }
   232  
   233  func (ep *wireEndpoint) MaxHeaderLength() uint16 {
   234  	return 0
   235  }
   236  
   237  func (ep *wireEndpoint) LinkAddress() tcpip.LinkAddress {
   238  	return ""
   239  }
   240  
   241  func (ep *wireEndpoint) Capabilities() stack.LinkEndpointCapabilities {
   242  	return stack.CapabilityNone
   243  }
   244  
   245  func (ep *wireEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
   246  	ep.dispatcher = dispatcher
   247  }
   248  
   249  func (ep *wireEndpoint) IsAttached() bool {
   250  	return ep.dispatcher != nil
   251  }
   252  
   253  func (ep *wireEndpoint) Wait() {
   254  }
   255  
   256  func (ep *wireEndpoint) ARPHardwareType() header.ARPHardwareType {
   257  	return header.ARPHardwareNone
   258  }
   259  
   260  func (ep *wireEndpoint) AddHeader(buffer *stack.PacketBuffer) {
   261  }
   262  
   263  func (ep *wireEndpoint) WritePackets(list stack.PacketBufferList) (int, tcpip.Error) {
   264  	for _, packetBuffer := range list.AsSlice() {
   265  		packetBuffer.IncRef()
   266  		select {
   267  		case <-ep.done:
   268  			return 0, &tcpip.ErrClosedForSend{}
   269  		case ep.outbound <- packetBuffer:
   270  		}
   271  	}
   272  	return list.Len(), nil
   273  }