github.com/xmplusdev/xray-core@v1.8.10/proxy/wireguard/tun_linux.go (about)

     1  //go:build linux && !android
     2  
     3  package wireguard
     4  
     5  import (
     6  	"context"
     7  	"errors"
     8  	"fmt"
     9  	"net"
    10  	"net/netip"
    11  	"os"
    12  
    13  	"golang.org/x/sys/unix"
    14  
    15  	"github.com/sagernet/sing/common/control"
    16  	"github.com/vishvananda/netlink"
    17  	wgtun "golang.zx2c4.com/wireguard/tun"
    18  )
    19  
    20  type deviceNet struct {
    21  	tunnel
    22  	dialer net.Dialer
    23  
    24  	handle    *netlink.Handle
    25  	linkAddrs []netlink.Addr
    26  	routes    []*netlink.Route
    27  	rules     []*netlink.Rule
    28  }
    29  
    30  func newDeviceNet(interfaceName string) *deviceNet {
    31  	var dialer net.Dialer
    32  	bindControl := control.BindToInterface(control.DefaultInterfaceFinder(), interfaceName, -1)
    33  	dialer.Control = control.Append(dialer.Control, bindControl)
    34  	return &deviceNet{dialer: dialer}
    35  }
    36  
    37  func (d *deviceNet) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (
    38  	net.Conn, error,
    39  ) {
    40  	return d.dialer.DialContext(ctx, "tcp", addr.String())
    41  }
    42  
    43  func (d *deviceNet) DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, error) {
    44  	dialer := d.dialer
    45  	dialer.LocalAddr = &net.UDPAddr{IP: laddr.Addr().AsSlice(), Port: int(laddr.Port())}
    46  	return dialer.DialContext(context.Background(), "udp", raddr.String())
    47  }
    48  
    49  func (d *deviceNet) Close() (err error) {
    50  	var errs []error
    51  	for _, rule := range d.rules {
    52  		if err = d.handle.RuleDel(rule); err != nil {
    53  			errs = append(errs, fmt.Errorf("failed to delete rule: %w", err))
    54  		}
    55  	}
    56  	for _, route := range d.routes {
    57  		if err = d.handle.RouteDel(route); err != nil {
    58  			errs = append(errs, fmt.Errorf("failed to delete route: %w", err))
    59  		}
    60  	}
    61  	if err = d.tunnel.Close(); err != nil {
    62  		errs = append(errs, fmt.Errorf("failed to close tunnel: %w", err))
    63  	}
    64  	if d.handle != nil {
    65  		d.handle.Close()
    66  		d.handle = nil
    67  	}
    68  	if len(errs) == 0 {
    69  		return nil
    70  	}
    71  	return errors.Join(errs...)
    72  }
    73  
    74  func createKernelTun(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (t Tunnel, err error) {
    75  	if handler != nil {
    76  		return nil, newError("TODO: support promiscuous mode")
    77  	}
    78  
    79  	var v4, v6 *netip.Addr
    80  	for _, prefixes := range localAddresses {
    81  		if v4 == nil && prefixes.Is4() {
    82  			x := prefixes
    83  			v4 = &x
    84  		}
    85  		if v6 == nil && prefixes.Is6() {
    86  			x := prefixes
    87  			v6 = &x
    88  		}
    89  	}
    90  
    91  	writeSysctlZero := func(path string) error {
    92  		_, err := os.Stat(path)
    93  		if os.IsNotExist(err) {
    94  			return nil
    95  		}
    96  		if err != nil {
    97  			return err
    98  		}
    99  		return os.WriteFile(path, []byte("0"), 0o644)
   100  	}
   101  
   102  	// system configs.
   103  	if v4 != nil {
   104  		if err = writeSysctlZero("/proc/sys/net/ipv4/conf/all/rp_filter"); err != nil {
   105  			return nil, fmt.Errorf("failed to disable ipv4 rp_filter for all: %w", err)
   106  		}
   107  	}
   108  	if v6 != nil {
   109  		if err = writeSysctlZero("/proc/sys/net/ipv6/conf/all/disable_ipv6"); err != nil {
   110  			return nil, fmt.Errorf("failed to enable ipv6: %w", err)
   111  		}
   112  		if err = writeSysctlZero("/proc/sys/net/ipv6/conf/all/rp_filter"); err != nil {
   113  			return nil, fmt.Errorf("failed to disable ipv6 rp_filter for all: %w", err)
   114  		}
   115  	}
   116  
   117  	n := CalculateInterfaceName("wg")
   118  	wgt, err := wgtun.CreateTUN(n, mtu)
   119  	if err != nil {
   120  		return nil, err
   121  	}
   122  	defer func() {
   123  		if err != nil {
   124  			_ = wgt.Close()
   125  		}
   126  	}()
   127  
   128  	// disable linux rp_filter for tunnel device to avoid packet drop.
   129  	// the operation require root privilege on container require '--privileged' flag.
   130  	if v4 != nil {
   131  		if err = writeSysctlZero("/proc/sys/net/ipv4/conf/" + n + "/rp_filter"); err != nil {
   132  			return nil, fmt.Errorf("failed to disable ipv4 rp_filter for tunnel: %w", err)
   133  		}
   134  	}
   135  	if v6 != nil {
   136  		if err = writeSysctlZero("/proc/sys/net/ipv6/conf/" + n + "/rp_filter"); err != nil {
   137  			return nil, fmt.Errorf("failed to disable ipv6 rp_filter for tunnel: %w", err)
   138  		}
   139  	}
   140  
   141  	ipv6TableIndex := 1023
   142  	if v6 != nil {
   143  		r := &netlink.Route{Table: ipv6TableIndex}
   144  		for {
   145  			routeList, fErr := netlink.RouteListFiltered(netlink.FAMILY_V6, r, netlink.RT_FILTER_TABLE)
   146  			if len(routeList) == 0 || fErr != nil {
   147  				break
   148  			}
   149  			ipv6TableIndex--
   150  			if ipv6TableIndex < 0 {
   151  				return nil, fmt.Errorf("failed to find available ipv6 table index")
   152  			}
   153  		}
   154  	}
   155  
   156  	out := newDeviceNet(n)
   157  	out.handle, err = netlink.NewHandle()
   158  	if err != nil {
   159  		return nil, err
   160  	}
   161  	defer func() {
   162  		if err != nil {
   163  			_ = out.Close()
   164  		}
   165  	}()
   166  
   167  	l, err := netlink.LinkByName(n)
   168  	if err != nil {
   169  		return nil, err
   170  	}
   171  
   172  	if v4 != nil {
   173  		addr := netlink.Addr{
   174  			IPNet: &net.IPNet{
   175  				IP:   v4.AsSlice(),
   176  				Mask: net.CIDRMask(v4.BitLen(), v4.BitLen()),
   177  			},
   178  		}
   179  		out.linkAddrs = append(out.linkAddrs, addr)
   180  	}
   181  	if v6 != nil {
   182  		addr := netlink.Addr{
   183  			IPNet: &net.IPNet{
   184  				IP:   v6.AsSlice(),
   185  				Mask: net.CIDRMask(v6.BitLen(), v6.BitLen()),
   186  			},
   187  		}
   188  		out.linkAddrs = append(out.linkAddrs, addr)
   189  
   190  		rt := &netlink.Route{
   191  			LinkIndex: l.Attrs().Index,
   192  			Dst: &net.IPNet{
   193  				IP:   net.IPv6zero,
   194  				Mask: net.CIDRMask(0, 128),
   195  			},
   196  			Table: ipv6TableIndex,
   197  		}
   198  		out.routes = append(out.routes, rt)
   199  
   200  		r := netlink.NewRule()
   201  		r.Table, r.Family, r.Src = ipv6TableIndex, unix.AF_INET6, addr.IPNet
   202  		out.rules = append(out.rules, r)
   203  	}
   204  
   205  	for _, addr := range out.linkAddrs {
   206  		if err = out.handle.AddrAdd(l, &addr); err != nil {
   207  			return nil, fmt.Errorf("failed to add address %s to %s: %w", addr, n, err)
   208  		}
   209  	}
   210  	if err = out.handle.LinkSetMTU(l, mtu); err != nil {
   211  		return nil, err
   212  	}
   213  	if err = out.handle.LinkSetUp(l); err != nil {
   214  		return nil, err
   215  	}
   216  
   217  	for _, route := range out.routes {
   218  		if err = out.handle.RouteAdd(route); err != nil {
   219  			return nil, fmt.Errorf("failed to add route %s: %w", route, err)
   220  		}
   221  	}
   222  	for _, rule := range out.rules {
   223  		if err = out.handle.RuleAdd(rule); err != nil {
   224  			return nil, fmt.Errorf("failed to add rule %s: %w", rule, err)
   225  		}
   226  	}
   227  	out.tun = wgt
   228  	return out, nil
   229  }
   230  
   231  func KernelTunSupported() bool {
   232  	// run a superuser permission check to check
   233  	// if the current user has the sufficient permission
   234  	// to create a tun device.
   235  
   236  	return unix.Geteuid() == 0 // 0 means root
   237  }