github.com/yaling888/clash@v1.53.0/transport/wireguard/tun.go (about)

     1  //go:build !nogvisor
     2  
     3  package wireguard
     4  
     5  import (
     6  	"context"
     7  	"fmt"
     8  	"net"
     9  	"net/netip"
    10  	"os"
    11  	"syscall"
    12  
    13  	"golang.zx2c4.com/wireguard/tun"
    14  	"gvisor.dev/gvisor/pkg/buffer"
    15  	"gvisor.dev/gvisor/pkg/tcpip"
    16  	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
    17  	"gvisor.dev/gvisor/pkg/tcpip/header"
    18  	"gvisor.dev/gvisor/pkg/tcpip/link/channel"
    19  	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
    20  	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
    21  	"gvisor.dev/gvisor/pkg/tcpip/stack"
    22  	"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
    23  	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
    24  	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
    25  )
    26  
    27  type netTun struct {
    28  	ep             *channel.Endpoint
    29  	stack          *stack.Stack
    30  	events         chan tun.Event
    31  	incomingPacket chan *buffer.View
    32  	mtu            int
    33  	dnsServers     []netip.Addr
    34  	hasV4, hasV6   bool
    35  }
    36  
    37  type Net netTun
    38  
    39  func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) {
    40  	opts := stack.Options{
    41  		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
    42  		TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4},
    43  		HandleLocal:        true,
    44  	}
    45  	dev := &netTun{
    46  		ep:             channel.New(1024, uint32(mtu), ""),
    47  		stack:          stack.New(opts),
    48  		events:         make(chan tun.Event, 10),
    49  		incomingPacket: make(chan *buffer.View),
    50  		dnsServers:     dnsServers,
    51  		mtu:            mtu,
    52  	}
    53  	sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default
    54  	tcpipErr := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt)
    55  	if tcpipErr != nil {
    56  		return nil, nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr)
    57  	}
    58  	dev.ep.AddNotify(dev)
    59  	tcpipErr = dev.stack.CreateNIC(1, dev.ep)
    60  	if tcpipErr != nil {
    61  		return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr)
    62  	}
    63  	for _, ip := range localAddresses {
    64  		var protoNumber tcpip.NetworkProtocolNumber
    65  		if ip.Is4() {
    66  			protoNumber = ipv4.ProtocolNumber
    67  		} else if ip.Is6() {
    68  			protoNumber = ipv6.ProtocolNumber
    69  		}
    70  		protoAddr := tcpip.ProtocolAddress{
    71  			Protocol:          protoNumber,
    72  			AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(),
    73  		}
    74  		tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
    75  		if tcpipErr != nil {
    76  			return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
    77  		}
    78  		if ip.Is4() {
    79  			dev.hasV4 = true
    80  		} else if ip.Is6() {
    81  			dev.hasV6 = true
    82  		}
    83  	}
    84  	if dev.hasV4 {
    85  		dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1})
    86  	}
    87  	if dev.hasV6 {
    88  		dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1})
    89  	}
    90  
    91  	dev.events <- tun.EventUp
    92  	return dev, (*Net)(dev), nil
    93  }
    94  
    95  func (tun *netTun) Name() (string, error) {
    96  	return "Clash-WireGuard", nil
    97  }
    98  
    99  func (tun *netTun) File() *os.File {
   100  	return nil
   101  }
   102  
   103  func (tun *netTun) Events() <-chan tun.Event {
   104  	return tun.events
   105  }
   106  
   107  func (tun *netTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
   108  	view, ok := <-tun.incomingPacket
   109  	if !ok {
   110  		return 0, os.ErrClosed
   111  	}
   112  
   113  	defer view.Release()
   114  
   115  	n, err := view.Read(bufs[0][offset:])
   116  	if err != nil {
   117  		return 0, err
   118  	}
   119  	sizes[0] = n
   120  	return 1, nil
   121  }
   122  
   123  func (tun *netTun) Write(bufs [][]byte, offset int) (int, error) {
   124  	for i, buf := range bufs {
   125  		packet := buf[offset:]
   126  		if len(packet) == 0 {
   127  			continue
   128  		}
   129  
   130  		pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)})
   131  		switch packet[0] >> 4 {
   132  		case 4:
   133  			tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb)
   134  		case 6:
   135  			tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb)
   136  		default:
   137  			pkb.DecRef()
   138  			return i, syscall.EAFNOSUPPORT
   139  		}
   140  		pkb.DecRef()
   141  	}
   142  	return len(bufs), nil
   143  }
   144  
   145  func (tun *netTun) WriteNotify() {
   146  	pkt := tun.ep.Read()
   147  	if pkt == nil {
   148  		return
   149  	}
   150  
   151  	view := pkt.ToView()
   152  	pkt.DecRef()
   153  
   154  	tun.incomingPacket <- view
   155  }
   156  
   157  func (tun *netTun) Close() error {
   158  	tun.stack.Destroy()
   159  
   160  	if tun.events != nil {
   161  		close(tun.events)
   162  	}
   163  
   164  	tun.ep.Close()
   165  
   166  	if tun.incomingPacket != nil {
   167  		close(tun.incomingPacket)
   168  	}
   169  
   170  	return nil
   171  }
   172  
   173  func (tun *netTun) MTU() (int, error) {
   174  	return tun.mtu, nil
   175  }
   176  
   177  func (tun *netTun) BatchSize() int {
   178  	return 1
   179  }
   180  
   181  func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
   182  	var protoNumber tcpip.NetworkProtocolNumber
   183  	if endpoint.Addr().Is4() {
   184  		protoNumber = ipv4.ProtocolNumber
   185  	} else {
   186  		protoNumber = ipv6.ProtocolNumber
   187  	}
   188  	return tcpip.FullAddress{
   189  		NIC:  1,
   190  		Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()),
   191  		Port: endpoint.Port(),
   192  	}, protoNumber
   193  }
   194  
   195  func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) {
   196  	fa, pn := convertToFullAddr(addr)
   197  	return gonet.DialContextTCP(ctx, net.stack, fa, pn)
   198  }
   199  
   200  func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) {
   201  	if addr == nil {
   202  		return net.DialContextTCPAddrPort(ctx, netip.AddrPort{})
   203  	}
   204  	ip, _ := netip.AddrFromSlice(addr.IP)
   205  	return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(ip, uint16(addr.Port)))
   206  }
   207  
   208  func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) {
   209  	fa, pn := convertToFullAddr(addr)
   210  	return gonet.DialTCP(net.stack, fa, pn)
   211  }
   212  
   213  func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) {
   214  	if addr == nil {
   215  		return net.DialTCPAddrPort(netip.AddrPort{})
   216  	}
   217  	ip, _ := netip.AddrFromSlice(addr.IP)
   218  	return net.DialTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port)))
   219  }
   220  
   221  func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) {
   222  	fa, pn := convertToFullAddr(addr)
   223  	return gonet.ListenTCP(net.stack, fa, pn)
   224  }
   225  
   226  func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) {
   227  	if addr == nil {
   228  		return net.ListenTCPAddrPort(netip.AddrPort{})
   229  	}
   230  	ip, _ := netip.AddrFromSlice(addr.IP)
   231  	return net.ListenTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port)))
   232  }
   233  
   234  func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) {
   235  	var lfa, rfa *tcpip.FullAddress
   236  	var pn tcpip.NetworkProtocolNumber
   237  	if laddr.IsValid() || laddr.Port() > 0 {
   238  		var addr tcpip.FullAddress
   239  		addr, pn = convertToFullAddr(laddr)
   240  		lfa = &addr
   241  	}
   242  	if raddr.IsValid() || raddr.Port() > 0 {
   243  		var addr tcpip.FullAddress
   244  		addr, pn = convertToFullAddr(raddr)
   245  		rfa = &addr
   246  	}
   247  	return gonet.DialUDP(net.stack, lfa, rfa, pn)
   248  }
   249  
   250  func (net *Net) ListenUDPAddrPort(laddr netip.AddrPort) (*gonet.UDPConn, error) {
   251  	return net.DialUDPAddrPort(laddr, netip.AddrPort{})
   252  }
   253  
   254  func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
   255  	var la, ra netip.AddrPort
   256  	if laddr != nil {
   257  		ip, _ := netip.AddrFromSlice(laddr.IP)
   258  		la = netip.AddrPortFrom(ip, uint16(laddr.Port))
   259  	}
   260  	if raddr != nil {
   261  		ip, _ := netip.AddrFromSlice(raddr.IP)
   262  		ra = netip.AddrPortFrom(ip, uint16(raddr.Port))
   263  	}
   264  	return net.DialUDPAddrPort(la, ra)
   265  }
   266  
   267  func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) {
   268  	return net.DialUDP(laddr, nil)
   269  }