github.com/moqsien/xraycore@v1.8.5/proxy/wireguard/tun.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2022 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package wireguard
     7  
     8  import (
     9  	"context"
    10  	"fmt"
    11  	"net"
    12  	"net/netip"
    13  	"os"
    14  
    15  	"github.com/sagernet/wireguard-go/tun"
    16  	"github.com/moqsien/xraycore/features/dns"
    17  	"gvisor.dev/gvisor/pkg/buffer"
    18  	"gvisor.dev/gvisor/pkg/tcpip"
    19  	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
    20  	"gvisor.dev/gvisor/pkg/tcpip/header"
    21  	"gvisor.dev/gvisor/pkg/tcpip/link/channel"
    22  	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
    23  	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
    24  	"gvisor.dev/gvisor/pkg/tcpip/stack"
    25  	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
    26  	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
    27  )
    28  
    29  type netTun struct {
    30  	ep             *channel.Endpoint
    31  	stack          *stack.Stack
    32  	events         chan tun.Event
    33  	incomingPacket chan *buffer.View
    34  	mtu            int
    35  	dnsClient      dns.Client
    36  	hasV4, hasV6   bool
    37  }
    38  
    39  type Net netTun
    40  
    41  func CreateNetTUN(localAddresses []netip.Addr, dnsClient dns.Client, mtu int) (tun.Device, *Net, error) {
    42  	opts := stack.Options{
    43  		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
    44  		TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
    45  		HandleLocal:        true,
    46  	}
    47  	dev := &netTun{
    48  		ep:             channel.New(1024, uint32(mtu), ""),
    49  		stack:          stack.New(opts),
    50  		events:         make(chan tun.Event, 10),
    51  		incomingPacket: make(chan *buffer.View),
    52  		dnsClient:      dnsClient,
    53  		mtu:            mtu,
    54  	}
    55  	dev.ep.AddNotify(dev)
    56  	tcpipErr := dev.stack.CreateNIC(1, dev.ep)
    57  	if tcpipErr != nil {
    58  		return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr)
    59  	}
    60  	for _, ip := range localAddresses {
    61  		var protoNumber tcpip.NetworkProtocolNumber
    62  		if ip.Is4() {
    63  			protoNumber = ipv4.ProtocolNumber
    64  		} else if ip.Is6() {
    65  			protoNumber = ipv6.ProtocolNumber
    66  		}
    67  		protoAddr := tcpip.ProtocolAddress{
    68  			Protocol:          protoNumber,
    69  			AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(),
    70  		}
    71  		tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
    72  		if tcpipErr != nil {
    73  			return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
    74  		}
    75  		if ip.Is4() {
    76  			dev.hasV4 = true
    77  		} else if ip.Is6() {
    78  			dev.hasV6 = true
    79  		}
    80  	}
    81  	if dev.hasV4 {
    82  		dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1})
    83  	}
    84  	if dev.hasV6 {
    85  		dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1})
    86  	}
    87  
    88  	dev.events <- tun.EventUp
    89  	return dev, (*Net)(dev), nil
    90  }
    91  
    92  func (tun *netTun) Name() (string, error) {
    93  	return "go", nil
    94  }
    95  
    96  func (tun *netTun) File() *os.File {
    97  	return nil
    98  }
    99  
   100  func (tun *netTun) Events() <-chan tun.Event {
   101  	return tun.events
   102  }
   103  
   104  func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) {
   105  	view, ok := <-tun.incomingPacket
   106  	if !ok {
   107  		return 0, os.ErrClosed
   108  	}
   109  
   110  	return view.Read(buf[0][offset:])
   111  }
   112  
   113  func (tun *netTun) Write(buf [][]byte, offset int) (count int, err error) {
   114  	for _, b := range buf {
   115  		packet := b[offset:]
   116  		if len(packet) == 0 {
   117  			continue
   118  		}
   119  
   120  		pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)})
   121  		switch packet[0] >> 4 {
   122  		case 4:
   123  			tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb)
   124  		case 6:
   125  			tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb)
   126  		}
   127  		count++
   128  	}
   129  	return
   130  }
   131  
   132  func (tun *netTun) WriteNotify() {
   133  	pkt := tun.ep.Read()
   134  	if pkt == nil {
   135  		return
   136  	}
   137  
   138  	view := pkt.ToView()
   139  	pkt.DecRef()
   140  
   141  	tun.incomingPacket <- view
   142  }
   143  
   144  func (tun *netTun) Flush() error {
   145  	return nil
   146  }
   147  
   148  func (tun *netTun) Close() error {
   149  	tun.stack.RemoveNIC(1)
   150  
   151  	if tun.events != nil {
   152  		close(tun.events)
   153  	}
   154  
   155  	tun.ep.Close()
   156  
   157  	if tun.incomingPacket != nil {
   158  		close(tun.incomingPacket)
   159  	}
   160  
   161  	return nil
   162  }
   163  
   164  func (tun *netTun) MTU() (int, error) {
   165  	return tun.mtu, nil
   166  }
   167  
   168  func (tun *netTun) BatchSize() int {
   169  	return 1
   170  }
   171  
   172  func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
   173  	var protoNumber tcpip.NetworkProtocolNumber
   174  	if endpoint.Addr().Is4() {
   175  		protoNumber = ipv4.ProtocolNumber
   176  	} else {
   177  		protoNumber = ipv6.ProtocolNumber
   178  	}
   179  	return tcpip.FullAddress{
   180  		NIC:  1,
   181  		Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()),
   182  		Port: endpoint.Port(),
   183  	}, protoNumber
   184  }
   185  
   186  func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) {
   187  	fa, pn := convertToFullAddr(addr)
   188  	return gonet.DialContextTCP(ctx, net.stack, fa, pn)
   189  }
   190  
   191  func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) {
   192  	if addr == nil {
   193  		return net.DialContextTCPAddrPort(ctx, netip.AddrPort{})
   194  	}
   195  	ip, _ := netip.AddrFromSlice(addr.IP)
   196  	return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(ip, uint16(addr.Port)))
   197  }
   198  
   199  func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) {
   200  	fa, pn := convertToFullAddr(addr)
   201  	return gonet.DialTCP(net.stack, fa, pn)
   202  }
   203  
   204  func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) {
   205  	if addr == nil {
   206  		return net.DialTCPAddrPort(netip.AddrPort{})
   207  	}
   208  	ip, _ := netip.AddrFromSlice(addr.IP)
   209  	return net.DialTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port)))
   210  }
   211  
   212  func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) {
   213  	fa, pn := convertToFullAddr(addr)
   214  	return gonet.ListenTCP(net.stack, fa, pn)
   215  }
   216  
   217  func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) {
   218  	if addr == nil {
   219  		return net.ListenTCPAddrPort(netip.AddrPort{})
   220  	}
   221  	ip, _ := netip.AddrFromSlice(addr.IP)
   222  	return net.ListenTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port)))
   223  }
   224  
   225  func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) {
   226  	var lfa, rfa *tcpip.FullAddress
   227  	var pn tcpip.NetworkProtocolNumber
   228  	if laddr.IsValid() || laddr.Port() > 0 {
   229  		var addr tcpip.FullAddress
   230  		addr, pn = convertToFullAddr(laddr)
   231  		lfa = &addr
   232  	}
   233  	if raddr.IsValid() || raddr.Port() > 0 {
   234  		var addr tcpip.FullAddress
   235  		addr, pn = convertToFullAddr(raddr)
   236  		rfa = &addr
   237  	}
   238  	return gonet.DialUDP(net.stack, lfa, rfa, pn)
   239  }
   240  
   241  func (net *Net) ListenUDPAddrPort(laddr netip.AddrPort) (*gonet.UDPConn, error) {
   242  	return net.DialUDPAddrPort(laddr, netip.AddrPort{})
   243  }
   244  
   245  func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
   246  	var la, ra netip.AddrPort
   247  	if laddr != nil {
   248  		ip, _ := netip.AddrFromSlice(laddr.IP)
   249  		la = netip.AddrPortFrom(ip, uint16(laddr.Port))
   250  	}
   251  	if raddr != nil {
   252  		ip, _ := netip.AddrFromSlice(raddr.IP)
   253  		ra = netip.AddrPortFrom(ip, uint16(raddr.Port))
   254  	}
   255  	return net.DialUDPAddrPort(la, ra)
   256  }
   257  
   258  func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) {
   259  	return net.DialUDP(laddr, nil)
   260  }
   261  
   262  func (n *Net) HasV4() bool {
   263  	return n.hasV4
   264  }
   265  
   266  func (n *Net) HasV6() bool {
   267  	return n.hasV6
   268  }
   269  
   270  func IsDomainName(s string) bool {
   271  	l := len(s)
   272  	if l == 0 || l > 254 || l == 254 && s[l-1] != '.' {
   273  		return false
   274  	}
   275  	last := byte('.')
   276  	nonNumeric := false
   277  	partlen := 0
   278  	for i := 0; i < len(s); i++ {
   279  		c := s[i]
   280  		switch {
   281  		default:
   282  			return false
   283  		case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_':
   284  			nonNumeric = true
   285  			partlen++
   286  		case '0' <= c && c <= '9':
   287  			partlen++
   288  		case c == '-':
   289  			if last == '.' {
   290  				return false
   291  			}
   292  			partlen++
   293  			nonNumeric = true
   294  		case c == '.':
   295  			if last == '.' || last == '-' {
   296  				return false
   297  			}
   298  			if partlen > 63 || partlen == 0 {
   299  				return false
   300  			}
   301  			partlen = 0
   302  		}
   303  		last = c
   304  	}
   305  	if last == '-' || partlen > 63 {
   306  		return false
   307  	}
   308  	return nonNumeric
   309  }