github.com/xtls/xray-core@v1.8.12-0.20240518155711-3168d27b0bdb/proxy/wireguard/gvisortun/tun.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2022 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package gvisortun
     7  
     8  import (
     9  	"context"
    10  	"fmt"
    11  	"net/netip"
    12  	"os"
    13  	"syscall"
    14  
    15  	"golang.zx2c4.com/wireguard/tun"
    16  	"gvisor.dev/gvisor/pkg/buffer"
    17  	"gvisor.dev/gvisor/pkg/tcpip"
    18  	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
    19  	"gvisor.dev/gvisor/pkg/tcpip/header"
    20  	"gvisor.dev/gvisor/pkg/tcpip/link/channel"
    21  	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
    22  	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
    23  	"gvisor.dev/gvisor/pkg/tcpip/stack"
    24  	"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
    25  	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
    26  	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
    27  )
    28  
    29  type netTun struct {
    30  	ep             *channel.Endpoint
    31  	stack          *stack.Stack
    32  	events         chan tun.Event
    33  	incomingPacket chan *buffer.View
    34  	mtu            int
    35  	hasV4, hasV6   bool
    36  }
    37  
    38  type Net netTun
    39  
    40  func CreateNetTUN(localAddresses []netip.Addr, mtu int, promiscuousMode bool) (tun.Device, *Net, *stack.Stack, error) {
    41  	opts := stack.Options{
    42  		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
    43  		TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4},
    44  		HandleLocal:        !promiscuousMode,
    45  	}
    46  	dev := &netTun{
    47  		ep:             channel.New(1024, uint32(mtu), ""),
    48  		stack:          stack.New(opts),
    49  		events:         make(chan tun.Event, 1),
    50  		incomingPacket: make(chan *buffer.View),
    51  		mtu:            mtu,
    52  	}
    53  	dev.ep.AddNotify(dev)
    54  	tcpipErr := dev.stack.CreateNIC(1, dev.ep)
    55  	if tcpipErr != nil {
    56  		return nil, nil, dev.stack, fmt.Errorf("CreateNIC: %v", tcpipErr)
    57  	}
    58  	for _, ip := range localAddresses {
    59  		var protoNumber tcpip.NetworkProtocolNumber
    60  		if ip.Is4() {
    61  			protoNumber = ipv4.ProtocolNumber
    62  		} else if ip.Is6() {
    63  			protoNumber = ipv6.ProtocolNumber
    64  		}
    65  		protoAddr := tcpip.ProtocolAddress{
    66  			Protocol:          protoNumber,
    67  			AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(),
    68  		}
    69  		tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
    70  		if tcpipErr != nil {
    71  			return nil, nil, dev.stack, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
    72  		}
    73  		if ip.Is4() {
    74  			dev.hasV4 = true
    75  		} else if ip.Is6() {
    76  			dev.hasV6 = true
    77  		}
    78  	}
    79  	if dev.hasV4 {
    80  		dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1})
    81  	}
    82  	if dev.hasV6 {
    83  		dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1})
    84  	}
    85  	if promiscuousMode {
    86  		// enable promiscuous mode to handle all packets processed by netstack
    87  		dev.stack.SetPromiscuousMode(1, true)
    88  		dev.stack.SetSpoofing(1, true)
    89  	}
    90  
    91  	opt := tcpip.CongestionControlOption("cubic")
    92  	if err := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
    93  		return nil, nil, dev.stack, fmt.Errorf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, opt, err)
    94  	}
    95  
    96  	dev.events <- tun.EventUp
    97  	return dev, (*Net)(dev), dev.stack, nil
    98  }
    99  
   100  // BatchSize implements tun.Device
   101  func (tun *netTun) BatchSize() int {
   102  	return 1
   103  }
   104  
   105  // Name implements tun.Device
   106  func (tun *netTun) Name() (string, error) {
   107  	return "go", nil
   108  }
   109  
   110  // File implements tun.Device
   111  func (tun *netTun) File() *os.File {
   112  	return nil
   113  }
   114  
   115  // Events implements tun.Device
   116  func (tun *netTun) Events() <-chan tun.Event {
   117  	return tun.events
   118  }
   119  
   120  // Read implements tun.Device
   121  
   122  func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) {
   123  	view, ok := <-tun.incomingPacket
   124  	if !ok {
   125  		return 0, os.ErrClosed
   126  	}
   127  
   128  	n, err := view.Read(buf[0][offset:])
   129  	if err != nil {
   130  		return 0, err
   131  	}
   132  	sizes[0] = n
   133  	return 1, nil
   134  }
   135  
   136  // Write implements tun.Device
   137  func (tun *netTun) Write(buf [][]byte, offset int) (int, error) {
   138  	for _, buf := range buf {
   139  		packet := buf[offset:]
   140  		if len(packet) == 0 {
   141  			continue
   142  		}
   143  
   144  		pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)})
   145  		switch packet[0] >> 4 {
   146  		case 4:
   147  			tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb)
   148  		case 6:
   149  			tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb)
   150  		default:
   151  			return 0, syscall.EAFNOSUPPORT
   152  		}
   153  	}
   154  	return len(buf), nil
   155  }
   156  
   157  // WriteNotify implements channel.Notification
   158  func (tun *netTun) WriteNotify() {
   159  	pkt := tun.ep.Read()
   160  	if pkt.IsNil() {
   161  		return
   162  	}
   163  
   164  	view := pkt.ToView()
   165  	pkt.DecRef()
   166  
   167  	tun.incomingPacket <- view
   168  }
   169  
   170  // Flush  implements tun.Device
   171  func (tun *netTun) Flush() error {
   172  	return nil
   173  }
   174  
   175  // Close implements tun.Device
   176  func (tun *netTun) Close() error {
   177  	tun.stack.RemoveNIC(1)
   178  
   179  	if tun.events != nil {
   180  		close(tun.events)
   181  	}
   182  
   183  	tun.ep.Close()
   184  
   185  	if tun.incomingPacket != nil {
   186  		close(tun.incomingPacket)
   187  	}
   188  
   189  	return nil
   190  }
   191  
   192  // MTU  implements tun.Device
   193  func (tun *netTun) MTU() (int, error) {
   194  	return tun.mtu, nil
   195  }
   196  
   197  func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
   198  	var protoNumber tcpip.NetworkProtocolNumber
   199  	if endpoint.Addr().Is4() {
   200  		protoNumber = ipv4.ProtocolNumber
   201  	} else {
   202  		protoNumber = ipv6.ProtocolNumber
   203  	}
   204  	return tcpip.FullAddress{
   205  		NIC:  1,
   206  		Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()),
   207  		Port: endpoint.Port(),
   208  	}, protoNumber
   209  }
   210  
   211  func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) {
   212  	fa, pn := convertToFullAddr(addr)
   213  	return gonet.DialContextTCP(ctx, net.stack, fa, pn)
   214  }
   215  
   216  func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) {
   217  	var lfa, rfa *tcpip.FullAddress
   218  	var pn tcpip.NetworkProtocolNumber
   219  	if laddr.IsValid() || laddr.Port() > 0 {
   220  		var addr tcpip.FullAddress
   221  		addr, pn = convertToFullAddr(laddr)
   222  		lfa = &addr
   223  	}
   224  	if raddr.IsValid() || raddr.Port() > 0 {
   225  		var addr tcpip.FullAddress
   226  		addr, pn = convertToFullAddr(raddr)
   227  		rfa = &addr
   228  	}
   229  	return gonet.DialUDP(net.stack, lfa, rfa, pn)
   230  }