github.com/xraypb/Xray-core@v1.8.1/proxy/wireguard/tun.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2022 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package wireguard
     7  
     8  import (
     9  	"context"
    10  	"fmt"
    11  	"net"
    12  	"net/netip"
    13  	"os"
    14  
    15  	"github.com/sagernet/wireguard-go/tun"
    16  	"github.com/xraypb/Xray-core/features/dns"
    17  	"gvisor.dev/gvisor/pkg/bufferv2"
    18  	"gvisor.dev/gvisor/pkg/tcpip"
    19  	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
    20  	"gvisor.dev/gvisor/pkg/tcpip/header"
    21  	"gvisor.dev/gvisor/pkg/tcpip/link/channel"
    22  	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
    23  	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
    24  	"gvisor.dev/gvisor/pkg/tcpip/stack"
    25  	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
    26  	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
    27  )
    28  
    29  type netTun struct {
    30  	ep             *channel.Endpoint
    31  	stack          *stack.Stack
    32  	events         chan tun.Event
    33  	incomingPacket chan *bufferv2.View
    34  	mtu            int
    35  	dnsClient      dns.Client
    36  	hasV4, hasV6   bool
    37  }
    38  
    39  type Net netTun
    40  
    41  func CreateNetTUN(localAddresses []netip.Addr, dnsClient dns.Client, mtu int) (tun.Device, *Net, error) {
    42  	opts := stack.Options{
    43  		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
    44  		TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
    45  		HandleLocal:        true,
    46  	}
    47  	dev := &netTun{
    48  		ep:             channel.New(1024, uint32(mtu), ""),
    49  		stack:          stack.New(opts),
    50  		events:         make(chan tun.Event, 10),
    51  		incomingPacket: make(chan *bufferv2.View),
    52  		dnsClient:      dnsClient,
    53  		mtu:            mtu,
    54  	}
    55  	dev.ep.AddNotify(dev)
    56  	tcpipErr := dev.stack.CreateNIC(1, dev.ep)
    57  	if tcpipErr != nil {
    58  		return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr)
    59  	}
    60  	for _, ip := range localAddresses {
    61  		var protoNumber tcpip.NetworkProtocolNumber
    62  		if ip.Is4() {
    63  			protoNumber = ipv4.ProtocolNumber
    64  		} else if ip.Is6() {
    65  			protoNumber = ipv6.ProtocolNumber
    66  		}
    67  		protoAddr := tcpip.ProtocolAddress{
    68  			Protocol:          protoNumber,
    69  			AddressWithPrefix: tcpip.Address(ip.AsSlice()).WithPrefix(),
    70  		}
    71  		tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
    72  		if tcpipErr != nil {
    73  			return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
    74  		}
    75  		if ip.Is4() {
    76  			dev.hasV4 = true
    77  		} else if ip.Is6() {
    78  			dev.hasV6 = true
    79  		}
    80  	}
    81  	if dev.hasV4 {
    82  		dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1})
    83  	}
    84  	if dev.hasV6 {
    85  		dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1})
    86  	}
    87  
    88  	dev.events <- tun.EventUp
    89  	return dev, (*Net)(dev), nil
    90  }
    91  
    92  func (tun *netTun) Name() (string, error) {
    93  	return "go", nil
    94  }
    95  
    96  func (tun *netTun) File() *os.File {
    97  	return nil
    98  }
    99  
   100  func (tun *netTun) Events() chan tun.Event {
   101  	return tun.events
   102  }
   103  
   104  func (tun *netTun) Read(buf []byte, offset int) (int, error) {
   105  	view, ok := <-tun.incomingPacket
   106  	if !ok {
   107  		return 0, os.ErrClosed
   108  	}
   109  
   110  	return view.Read(buf[offset:])
   111  }
   112  
   113  func (tun *netTun) Write(buf []byte, offset int) (int, error) {
   114  	packet := buf[offset:]
   115  	if len(packet) == 0 {
   116  		return 0, nil
   117  	}
   118  
   119  	pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: bufferv2.MakeWithData(packet)})
   120  	switch packet[0] >> 4 {
   121  	case 4:
   122  		tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb)
   123  	case 6:
   124  		tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb)
   125  	}
   126  
   127  	return len(buf), nil
   128  }
   129  
   130  func (tun *netTun) WriteNotify() {
   131  	pkt := tun.ep.Read()
   132  	if pkt == nil {
   133  		return
   134  	}
   135  
   136  	view := pkt.ToView()
   137  	pkt.DecRef()
   138  
   139  	tun.incomingPacket <- view
   140  }
   141  
   142  func (tun *netTun) Flush() error {
   143  	return nil
   144  }
   145  
   146  func (tun *netTun) Close() error {
   147  	tun.stack.RemoveNIC(1)
   148  
   149  	if tun.events != nil {
   150  		close(tun.events)
   151  	}
   152  
   153  	tun.ep.Close()
   154  
   155  	if tun.incomingPacket != nil {
   156  		close(tun.incomingPacket)
   157  	}
   158  
   159  	return nil
   160  }
   161  
   162  func (tun *netTun) MTU() (int, error) {
   163  	return tun.mtu, nil
   164  }
   165  
   166  func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
   167  	var protoNumber tcpip.NetworkProtocolNumber
   168  	if endpoint.Addr().Is4() {
   169  		protoNumber = ipv4.ProtocolNumber
   170  	} else {
   171  		protoNumber = ipv6.ProtocolNumber
   172  	}
   173  	return tcpip.FullAddress{
   174  		NIC:  1,
   175  		Addr: tcpip.Address(endpoint.Addr().AsSlice()),
   176  		Port: endpoint.Port(),
   177  	}, protoNumber
   178  }
   179  
   180  func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) {
   181  	fa, pn := convertToFullAddr(addr)
   182  	return gonet.DialContextTCP(ctx, net.stack, fa, pn)
   183  }
   184  
   185  func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) {
   186  	if addr == nil {
   187  		return net.DialContextTCPAddrPort(ctx, netip.AddrPort{})
   188  	}
   189  	ip, _ := netip.AddrFromSlice(addr.IP)
   190  	return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(ip, uint16(addr.Port)))
   191  }
   192  
   193  func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) {
   194  	fa, pn := convertToFullAddr(addr)
   195  	return gonet.DialTCP(net.stack, fa, pn)
   196  }
   197  
   198  func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) {
   199  	if addr == nil {
   200  		return net.DialTCPAddrPort(netip.AddrPort{})
   201  	}
   202  	ip, _ := netip.AddrFromSlice(addr.IP)
   203  	return net.DialTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port)))
   204  }
   205  
   206  func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) {
   207  	fa, pn := convertToFullAddr(addr)
   208  	return gonet.ListenTCP(net.stack, fa, pn)
   209  }
   210  
   211  func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) {
   212  	if addr == nil {
   213  		return net.ListenTCPAddrPort(netip.AddrPort{})
   214  	}
   215  	ip, _ := netip.AddrFromSlice(addr.IP)
   216  	return net.ListenTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port)))
   217  }
   218  
   219  func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) {
   220  	var lfa, rfa *tcpip.FullAddress
   221  	var pn tcpip.NetworkProtocolNumber
   222  	if laddr.IsValid() || laddr.Port() > 0 {
   223  		var addr tcpip.FullAddress
   224  		addr, pn = convertToFullAddr(laddr)
   225  		lfa = &addr
   226  	}
   227  	if raddr.IsValid() || raddr.Port() > 0 {
   228  		var addr tcpip.FullAddress
   229  		addr, pn = convertToFullAddr(raddr)
   230  		rfa = &addr
   231  	}
   232  	return gonet.DialUDP(net.stack, lfa, rfa, pn)
   233  }
   234  
   235  func (net *Net) ListenUDPAddrPort(laddr netip.AddrPort) (*gonet.UDPConn, error) {
   236  	return net.DialUDPAddrPort(laddr, netip.AddrPort{})
   237  }
   238  
   239  func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
   240  	var la, ra netip.AddrPort
   241  	if laddr != nil {
   242  		ip, _ := netip.AddrFromSlice(laddr.IP)
   243  		la = netip.AddrPortFrom(ip, uint16(laddr.Port))
   244  	}
   245  	if raddr != nil {
   246  		ip, _ := netip.AddrFromSlice(raddr.IP)
   247  		ra = netip.AddrPortFrom(ip, uint16(raddr.Port))
   248  	}
   249  	return net.DialUDPAddrPort(la, ra)
   250  }
   251  
   252  func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) {
   253  	return net.DialUDP(laddr, nil)
   254  }
   255  
   256  func (n *Net) HasV4() bool {
   257  	return n.hasV4
   258  }
   259  
   260  func (n *Net) HasV6() bool {
   261  	return n.hasV6
   262  }
   263  
   264  func IsDomainName(s string) bool {
   265  	l := len(s)
   266  	if l == 0 || l > 254 || l == 254 && s[l-1] != '.' {
   267  		return false
   268  	}
   269  	last := byte('.')
   270  	nonNumeric := false
   271  	partlen := 0
   272  	for i := 0; i < len(s); i++ {
   273  		c := s[i]
   274  		switch {
   275  		default:
   276  			return false
   277  		case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_':
   278  			nonNumeric = true
   279  			partlen++
   280  		case '0' <= c && c <= '9':
   281  			partlen++
   282  		case c == '-':
   283  			if last == '.' {
   284  				return false
   285  			}
   286  			partlen++
   287  			nonNumeric = true
   288  		case c == '.':
   289  			if last == '.' || last == '-' {
   290  				return false
   291  			}
   292  			if partlen > 63 || partlen == 0 {
   293  				return false
   294  			}
   295  			partlen = 0
   296  		}
   297  		last = c
   298  	}
   299  	if last == '-' || partlen > 63 {
   300  		return false
   301  	}
   302  	return nonNumeric
   303  }