github.com/sagernet/sing-box@v1.9.0-rc.20/transport/wireguard/device_stack.go (about)

     1  //go:build with_gvisor
     2  
     3  package wireguard
     4  
     5  import (
     6  	"context"
     7  	"net"
     8  	"net/netip"
     9  	"os"
    10  
    11  	"github.com/sagernet/gvisor/pkg/buffer"
    12  	"github.com/sagernet/gvisor/pkg/tcpip"
    13  	"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
    14  	"github.com/sagernet/gvisor/pkg/tcpip/header"
    15  	"github.com/sagernet/gvisor/pkg/tcpip/network/ipv4"
    16  	"github.com/sagernet/gvisor/pkg/tcpip/network/ipv6"
    17  	"github.com/sagernet/gvisor/pkg/tcpip/stack"
    18  	"github.com/sagernet/gvisor/pkg/tcpip/transport/icmp"
    19  	"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
    20  	"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
    21  	"github.com/sagernet/sing-tun"
    22  	"github.com/sagernet/sing/common/buf"
    23  	E "github.com/sagernet/sing/common/exceptions"
    24  	M "github.com/sagernet/sing/common/metadata"
    25  	N "github.com/sagernet/sing/common/network"
    26  	wgTun "github.com/sagernet/wireguard-go/tun"
    27  )
    28  
    29  var _ Device = (*StackDevice)(nil)
    30  
    31  const defaultNIC tcpip.NICID = 1
    32  
    33  type StackDevice struct {
    34  	stack          *stack.Stack
    35  	mtu            uint32
    36  	events         chan wgTun.Event
    37  	outbound       chan *stack.PacketBuffer
    38  	packetOutbound chan *buf.Buffer
    39  	done           chan struct{}
    40  	dispatcher     stack.NetworkDispatcher
    41  	addr4          tcpip.Address
    42  	addr6          tcpip.Address
    43  }
    44  
    45  func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (*StackDevice, error) {
    46  	ipStack := stack.New(stack.Options{
    47  		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
    48  		TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6},
    49  		HandleLocal:        true,
    50  	})
    51  	tunDevice := &StackDevice{
    52  		stack:          ipStack,
    53  		mtu:            mtu,
    54  		events:         make(chan wgTun.Event, 1),
    55  		outbound:       make(chan *stack.PacketBuffer, 256),
    56  		packetOutbound: make(chan *buf.Buffer, 256),
    57  		done:           make(chan struct{}),
    58  	}
    59  	err := ipStack.CreateNIC(defaultNIC, (*wireEndpoint)(tunDevice))
    60  	if err != nil {
    61  		return nil, E.New(err.String())
    62  	}
    63  	for _, prefix := range localAddresses {
    64  		addr := tun.AddressFromAddr(prefix.Addr())
    65  		protoAddr := tcpip.ProtocolAddress{
    66  			AddressWithPrefix: tcpip.AddressWithPrefix{
    67  				Address:   addr,
    68  				PrefixLen: prefix.Bits(),
    69  			},
    70  		}
    71  		if prefix.Addr().Is4() {
    72  			tunDevice.addr4 = addr
    73  			protoAddr.Protocol = ipv4.ProtocolNumber
    74  		} else {
    75  			tunDevice.addr6 = addr
    76  			protoAddr.Protocol = ipv6.ProtocolNumber
    77  		}
    78  		err = ipStack.AddProtocolAddress(defaultNIC, protoAddr, stack.AddressProperties{})
    79  		if err != nil {
    80  			return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", err.String())
    81  		}
    82  	}
    83  	sOpt := tcpip.TCPSACKEnabled(true)
    84  	ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt)
    85  	cOpt := tcpip.CongestionControlOption("cubic")
    86  	ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &cOpt)
    87  	ipStack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: defaultNIC})
    88  	ipStack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: defaultNIC})
    89  	return tunDevice, nil
    90  }
    91  
    92  func (w *StackDevice) NewEndpoint() (stack.LinkEndpoint, error) {
    93  	return (*wireEndpoint)(w), nil
    94  }
    95  
    96  func (w *StackDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
    97  	addr := tcpip.FullAddress{
    98  		NIC:  defaultNIC,
    99  		Port: destination.Port,
   100  		Addr: tun.AddressFromAddr(destination.Addr),
   101  	}
   102  	bind := tcpip.FullAddress{
   103  		NIC: defaultNIC,
   104  	}
   105  	var networkProtocol tcpip.NetworkProtocolNumber
   106  	if destination.IsIPv4() {
   107  		networkProtocol = header.IPv4ProtocolNumber
   108  		bind.Addr = w.addr4
   109  	} else {
   110  		networkProtocol = header.IPv6ProtocolNumber
   111  		bind.Addr = w.addr6
   112  	}
   113  	switch N.NetworkName(network) {
   114  	case N.NetworkTCP:
   115  		tcpConn, err := DialTCPWithBind(ctx, w.stack, bind, addr, networkProtocol)
   116  		if err != nil {
   117  			return nil, err
   118  		}
   119  		return tcpConn, nil
   120  	case N.NetworkUDP:
   121  		udpConn, err := gonet.DialUDP(w.stack, &bind, &addr, networkProtocol)
   122  		if err != nil {
   123  			return nil, err
   124  		}
   125  		return udpConn, nil
   126  	default:
   127  		return nil, E.Extend(N.ErrUnknownNetwork, network)
   128  	}
   129  }
   130  
   131  func (w *StackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
   132  	bind := tcpip.FullAddress{
   133  		NIC: defaultNIC,
   134  	}
   135  	var networkProtocol tcpip.NetworkProtocolNumber
   136  	if destination.IsIPv4() {
   137  		networkProtocol = header.IPv4ProtocolNumber
   138  		bind.Addr = w.addr4
   139  	} else {
   140  		networkProtocol = header.IPv6ProtocolNumber
   141  		bind.Addr = w.addr6
   142  	}
   143  	udpConn, err := gonet.DialUDP(w.stack, &bind, nil, networkProtocol)
   144  	if err != nil {
   145  		return nil, err
   146  	}
   147  	return udpConn, nil
   148  }
   149  
   150  func (w *StackDevice) Inet4Address() netip.Addr {
   151  	return tun.AddrFromAddress(w.addr4)
   152  }
   153  
   154  func (w *StackDevice) Inet6Address() netip.Addr {
   155  	return tun.AddrFromAddress(w.addr6)
   156  }
   157  
   158  func (w *StackDevice) Start() error {
   159  	w.events <- wgTun.EventUp
   160  	return nil
   161  }
   162  
   163  func (w *StackDevice) File() *os.File {
   164  	return nil
   165  }
   166  
   167  func (w *StackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
   168  	select {
   169  	case packetBuffer, ok := <-w.outbound:
   170  		if !ok {
   171  			return 0, os.ErrClosed
   172  		}
   173  		defer packetBuffer.DecRef()
   174  		p := bufs[0]
   175  		p = p[offset:]
   176  		n := 0
   177  		for _, slice := range packetBuffer.AsSlices() {
   178  			n += copy(p[n:], slice)
   179  		}
   180  		sizes[0] = n
   181  		count = 1
   182  		return
   183  	case packet := <-w.packetOutbound:
   184  		defer packet.Release()
   185  		sizes[0] = copy(bufs[0][offset:], packet.Bytes())
   186  		count = 1
   187  		return
   188  	case <-w.done:
   189  		return 0, os.ErrClosed
   190  	}
   191  }
   192  
   193  func (w *StackDevice) Write(bufs [][]byte, offset int) (count int, err error) {
   194  	for _, b := range bufs {
   195  		b = b[offset:]
   196  		if len(b) == 0 {
   197  			continue
   198  		}
   199  		var networkProtocol tcpip.NetworkProtocolNumber
   200  		switch header.IPVersion(b) {
   201  		case header.IPv4Version:
   202  			networkProtocol = header.IPv4ProtocolNumber
   203  		case header.IPv6Version:
   204  			networkProtocol = header.IPv6ProtocolNumber
   205  		}
   206  		packetBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{
   207  			Payload: buffer.MakeWithData(b),
   208  		})
   209  		w.dispatcher.DeliverNetworkPacket(networkProtocol, packetBuffer)
   210  		packetBuffer.DecRef()
   211  		count++
   212  	}
   213  	return
   214  }
   215  
   216  func (w *StackDevice) Flush() error {
   217  	return nil
   218  }
   219  
   220  func (w *StackDevice) MTU() (int, error) {
   221  	return int(w.mtu), nil
   222  }
   223  
   224  func (w *StackDevice) Name() (string, error) {
   225  	return "sing-box", nil
   226  }
   227  
   228  func (w *StackDevice) Events() <-chan wgTun.Event {
   229  	return w.events
   230  }
   231  
   232  func (w *StackDevice) Close() error {
   233  	select {
   234  	case <-w.done:
   235  		return os.ErrClosed
   236  	default:
   237  	}
   238  	w.stack.Close()
   239  	for _, endpoint := range w.stack.CleanupEndpoints() {
   240  		endpoint.Abort()
   241  	}
   242  	w.stack.Wait()
   243  	close(w.done)
   244  	return nil
   245  }
   246  
   247  func (w *StackDevice) BatchSize() int {
   248  	return 1
   249  }
   250  
   251  var _ stack.LinkEndpoint = (*wireEndpoint)(nil)
   252  
   253  type wireEndpoint StackDevice
   254  
   255  func (ep *wireEndpoint) MTU() uint32 {
   256  	return ep.mtu
   257  }
   258  
   259  func (ep *wireEndpoint) MaxHeaderLength() uint16 {
   260  	return 0
   261  }
   262  
   263  func (ep *wireEndpoint) LinkAddress() tcpip.LinkAddress {
   264  	return ""
   265  }
   266  
   267  func (ep *wireEndpoint) Capabilities() stack.LinkEndpointCapabilities {
   268  	return stack.CapabilityRXChecksumOffload
   269  }
   270  
   271  func (ep *wireEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
   272  	ep.dispatcher = dispatcher
   273  }
   274  
   275  func (ep *wireEndpoint) IsAttached() bool {
   276  	return ep.dispatcher != nil
   277  }
   278  
   279  func (ep *wireEndpoint) Wait() {
   280  }
   281  
   282  func (ep *wireEndpoint) ARPHardwareType() header.ARPHardwareType {
   283  	return header.ARPHardwareNone
   284  }
   285  
   286  func (ep *wireEndpoint) AddHeader(buffer *stack.PacketBuffer) {
   287  }
   288  
   289  func (ep *wireEndpoint) ParseHeader(ptr *stack.PacketBuffer) bool {
   290  	return true
   291  }
   292  
   293  func (ep *wireEndpoint) WritePackets(list stack.PacketBufferList) (int, tcpip.Error) {
   294  	for _, packetBuffer := range list.AsSlice() {
   295  		packetBuffer.IncRef()
   296  		select {
   297  		case <-ep.done:
   298  			return 0, &tcpip.ErrClosedForSend{}
   299  		case ep.outbound <- packetBuffer:
   300  		}
   301  	}
   302  	return list.Len(), nil
   303  }