github.com/xmplusdev/xray-core@v1.8.10/proxy/wireguard/tun.go (about)

     1  package wireguard
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"net"
     8  	"net/netip"
     9  	"runtime"
    10  	"strconv"
    11  	"strings"
    12  	"sync"
    13  	"time"
    14  
    15  	"github.com/xmplusdev/xray-core/common/log"
    16  	xnet "github.com/xmplusdev/xray-core/common/net"
    17  	"github.com/xmplusdev/xray-core/proxy/wireguard/gvisortun"
    18  	"gvisor.dev/gvisor/pkg/tcpip"
    19  	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
    20  	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
    21  	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
    22  	"gvisor.dev/gvisor/pkg/waiter"
    23  
    24  	"golang.zx2c4.com/wireguard/conn"
    25  	"golang.zx2c4.com/wireguard/device"
    26  	"golang.zx2c4.com/wireguard/tun"
    27  )
    28  
    29  type tunCreator func(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (Tunnel, error)
    30  
    31  type promiscuousModeHandler func(dest xnet.Destination, conn net.Conn)
    32  
    33  type Tunnel interface {
    34  	BuildDevice(ipc string, bind conn.Bind) error
    35  	DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (net.Conn, error)
    36  	DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, error)
    37  	Close() error
    38  }
    39  
    40  type tunnel struct {
    41  	tun    tun.Device
    42  	device *device.Device
    43  	rw     sync.Mutex
    44  }
    45  
    46  func (t *tunnel) BuildDevice(ipc string, bind conn.Bind) (err error) {
    47  	t.rw.Lock()
    48  	defer t.rw.Unlock()
    49  
    50  	if t.device != nil {
    51  		return errors.New("device is already initialized")
    52  	}
    53  
    54  	logger := &device.Logger{
    55  		Verbosef: func(format string, args ...any) {
    56  			log.Record(&log.GeneralMessage{
    57  				Severity: log.Severity_Debug,
    58  				Content:  fmt.Sprintf(format, args...),
    59  			})
    60  		},
    61  		Errorf: func(format string, args ...any) {
    62  			log.Record(&log.GeneralMessage{
    63  				Severity: log.Severity_Error,
    64  				Content:  fmt.Sprintf(format, args...),
    65  			})
    66  		},
    67  	}
    68  
    69  	t.device = device.NewDevice(t.tun, bind, logger)
    70  	if err = t.device.IpcSet(ipc); err != nil {
    71  		return err
    72  	}
    73  	if err = t.device.Up(); err != nil {
    74  		return err
    75  	}
    76  	return nil
    77  }
    78  
    79  func (t *tunnel) Close() (err error) {
    80  	t.rw.Lock()
    81  	defer t.rw.Unlock()
    82  
    83  	if t.device == nil {
    84  		return nil
    85  	}
    86  
    87  	t.device.Close()
    88  	t.device = nil
    89  	err = t.tun.Close()
    90  	t.tun = nil
    91  	return nil
    92  }
    93  
    94  func CalculateInterfaceName(name string) (tunName string) {
    95  	if runtime.GOOS == "darwin" {
    96  		tunName = "utun"
    97  	} else if name != "" {
    98  		tunName = name
    99  	} else {
   100  		tunName = "tun"
   101  	}
   102  	interfaces, err := net.Interfaces()
   103  	if err != nil {
   104  		return
   105  	}
   106  	var tunIndex int
   107  	for _, netInterface := range interfaces {
   108  		if strings.HasPrefix(netInterface.Name, tunName) {
   109  			index, parseErr := strconv.ParseInt(netInterface.Name[len(tunName):], 10, 16)
   110  			if parseErr == nil {
   111  				tunIndex = int(index) + 1
   112  			}
   113  		}
   114  	}
   115  	tunName = fmt.Sprintf("%s%d", tunName, tunIndex)
   116  	return
   117  }
   118  
   119  var _ Tunnel = (*gvisorNet)(nil)
   120  
   121  type gvisorNet struct {
   122  	tunnel
   123  	net *gvisortun.Net
   124  }
   125  
   126  func (g *gvisorNet) Close() error {
   127  	return g.tunnel.Close()
   128  }
   129  
   130  func (g *gvisorNet) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (
   131  	net.Conn, error,
   132  ) {
   133  	return g.net.DialContextTCPAddrPort(ctx, addr)
   134  }
   135  
   136  func (g *gvisorNet) DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, error) {
   137  	return g.net.DialUDPAddrPort(laddr, raddr)
   138  }
   139  
   140  func createGVisorTun(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (Tunnel, error) {
   141  	out := &gvisorNet{}
   142  	tun, n, stack, err := gvisortun.CreateNetTUN(localAddresses, mtu, handler != nil)
   143  	if err != nil {
   144  		return nil, err
   145  	}
   146  
   147  	if handler != nil {
   148  		// handler is only used for promiscuous mode
   149  		// capture all packets and send to handler
   150  
   151  		tcpForwarder := tcp.NewForwarder(stack, 0, 65535, func(r *tcp.ForwarderRequest) {
   152  			go func(r *tcp.ForwarderRequest) {
   153  				var (
   154  					wq waiter.Queue
   155  					id = r.ID()
   156  				)
   157  
   158  				// Perform a TCP three-way handshake.
   159  				ep, err := r.CreateEndpoint(&wq)
   160  				if err != nil {
   161  					newError(err.String()).AtError().WriteToLog()
   162  					r.Complete(true)
   163  					return
   164  				}
   165  				r.Complete(false)
   166  				defer ep.Close()
   167  
   168  				// enable tcp keep-alive to prevent hanging connections
   169  				ep.SocketOptions().SetKeepAlive(true)
   170  
   171  				// local address is actually destination
   172  				handler(xnet.TCPDestination(xnet.IPAddress(id.LocalAddress.AsSlice()), xnet.Port(id.LocalPort)), gonet.NewTCPConn(&wq, ep))
   173  			}(r)
   174  		})
   175  		stack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
   176  
   177  		udpForwarder := udp.NewForwarder(stack, func(r *udp.ForwarderRequest) {
   178  			go func(r *udp.ForwarderRequest) {
   179  				var (
   180  					wq waiter.Queue
   181  					id = r.ID()
   182  				)
   183  
   184  				ep, err := r.CreateEndpoint(&wq)
   185  				if err != nil {
   186  					newError(err.String()).AtError().WriteToLog()
   187  					return
   188  				}
   189  				defer ep.Close()
   190  
   191  				// prevents hanging connections and ensure timely release
   192  				ep.SocketOptions().SetLinger(tcpip.LingerOption{
   193  					Enabled: true,
   194  					Timeout: 15 * time.Second,
   195  				})
   196  
   197  				handler(xnet.UDPDestination(xnet.IPAddress(id.LocalAddress.AsSlice()), xnet.Port(id.LocalPort)), gonet.NewUDPConn(stack, &wq, ep))
   198  			}(r)
   199  		})
   200  		stack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
   201  	}
   202  
   203  	out.tun, out.net = tun, n
   204  	return out, nil
   205  }