github.com/xtls/xray-core@v1.8.12-0.20240518155711-3168d27b0bdb/proxy/wireguard/client.go (about)

     1  /*
     2  
     3  Some of the code is copied from https://github.com/octeep/wireproxy; its license follows.
     4  
     5  Copyright (c) 2022 Wind T.F. Wong <octeep@pm.me>
     6  
     7  Permission to use, copy, modify, and distribute this software for any
     8  purpose with or without fee is hereby granted, provided that the above
     9  copyright notice and this permission notice appear in all copies.
    10  
    11  THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
    12  WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
    13  MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
    14  ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
    15  WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
    16  ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
    17  OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
    18  
    19  */
    20  
    21  package wireguard
    22  
    23  import (
    24  	"context"
    25  	"fmt"
    26  	"net/netip"
    27  	"strings"
    28  	"sync"
    29  
    30  	"github.com/xtls/xray-core/common"
    31  	"github.com/xtls/xray-core/common/buf"
    32  	"github.com/xtls/xray-core/common/dice"
    33  	"github.com/xtls/xray-core/common/log"
    34  	"github.com/xtls/xray-core/common/net"
    35  	"github.com/xtls/xray-core/common/protocol"
    36  	"github.com/xtls/xray-core/common/session"
    37  	"github.com/xtls/xray-core/common/signal"
    38  	"github.com/xtls/xray-core/common/task"
    39  	"github.com/xtls/xray-core/core"
    40  	"github.com/xtls/xray-core/features/dns"
    41  	"github.com/xtls/xray-core/features/policy"
    42  	"github.com/xtls/xray-core/transport"
    43  	"github.com/xtls/xray-core/transport/internet"
    44  )
    45  
    46  // Handler is an outbound connection that silently swallow the entire payload.
    47  type Handler struct {
    48  	conf          *DeviceConfig
    49  	net           Tunnel
    50  	bind          *netBindClient
    51  	policyManager policy.Manager
    52  	dns           dns.Client
    53  	// cached configuration
    54  	endpoints        []netip.Addr
    55  	hasIPv4, hasIPv6 bool
    56  	wgLock           sync.Mutex
    57  }
    58  
    59  // New creates a new wireguard handler.
    60  func New(ctx context.Context, conf *DeviceConfig) (*Handler, error) {
    61  	v := core.MustFromContext(ctx)
    62  
    63  	endpoints, hasIPv4, hasIPv6, err := parseEndpoints(conf)
    64  	if err != nil {
    65  		return nil, err
    66  	}
    67  
    68  	d := v.GetFeature(dns.ClientType()).(dns.Client)
    69  	return &Handler{
    70  		conf:          conf,
    71  		policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
    72  		dns:           d,
    73  		endpoints:     endpoints,
    74  		hasIPv4:       hasIPv4,
    75  		hasIPv6:       hasIPv6,
    76  	}, nil
    77  }
    78  
    79  func (h *Handler) processWireGuard(dialer internet.Dialer) (err error) {
    80  	h.wgLock.Lock()
    81  	defer h.wgLock.Unlock()
    82  
    83  	if h.bind != nil && h.bind.dialer == dialer && h.net != nil {
    84  		return nil
    85  	}
    86  
    87  	log.Record(&log.GeneralMessage{
    88  		Severity: log.Severity_Info,
    89  		Content:  "switching dialer",
    90  	})
    91  
    92  	if h.net != nil {
    93  		_ = h.net.Close()
    94  		h.net = nil
    95  	}
    96  	if h.bind != nil {
    97  		_ = h.bind.Close()
    98  		h.bind = nil
    99  	}
   100  
   101  	// bind := conn.NewStdNetBind() // TODO: conn.Bind wrapper for dialer
   102  	bind := &netBindClient{
   103  		netBind: netBind{
   104  			dns: h.dns,
   105  			dnsOption: dns.IPOption{
   106  				IPv4Enable: h.hasIPv4,
   107  				IPv6Enable: h.hasIPv6,
   108  			},
   109  			workers: int(h.conf.NumWorkers),
   110  		},
   111  		dialer:   dialer,
   112  		reserved: h.conf.Reserved,
   113  	}
   114  	defer func() {
   115  		if err != nil {
   116  			_ = bind.Close()
   117  		}
   118  	}()
   119  
   120  	h.net, err = h.makeVirtualTun(bind)
   121  	if err != nil {
   122  		return newError("failed to create virtual tun interface").Base(err)
   123  	}
   124  	h.bind = bind
   125  	return nil
   126  }
   127  
   128  // Process implements OutboundHandler.Dispatch().
   129  func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
   130  	outbounds := session.OutboundsFromContext(ctx)
   131  	ob := outbounds[len(outbounds) - 1]
   132  	if !ob.Target.IsValid() {
   133  		return newError("target not specified")
   134  	}
   135  	ob.Name = "wireguard"
   136  	ob.CanSpliceCopy = 3
   137  
   138  	if err := h.processWireGuard(dialer); err != nil {
   139  		return err
   140  	}
   141  
   142  	// Destination of the inner request.
   143  	destination := ob.Target
   144  	command := protocol.RequestCommandTCP
   145  	if destination.Network == net.Network_UDP {
   146  		command = protocol.RequestCommandUDP
   147  	}
   148  
   149  	// resolve dns
   150  	addr := destination.Address
   151  	if addr.Family().IsDomain() {
   152  		ips, err := h.dns.LookupIP(addr.Domain(), dns.IPOption{
   153  			IPv4Enable: h.hasIPv4 && h.conf.preferIP4(),
   154  			IPv6Enable: h.hasIPv6 && h.conf.preferIP6(),
   155  		})
   156  		{ // Resolve fallback
   157  			if (len(ips) == 0 || err != nil) && h.conf.hasFallback() {
   158  				ips, err = h.dns.LookupIP(addr.Domain(), dns.IPOption{
   159  					IPv4Enable: h.hasIPv4 && h.conf.fallbackIP4(),
   160  					IPv6Enable: h.hasIPv6 && h.conf.fallbackIP6(),
   161  				})
   162  			}
   163  		}
   164  		if err != nil {
   165  			return newError("failed to lookup DNS").Base(err)
   166  		} else if len(ips) == 0 {
   167  			return dns.ErrEmptyResponse
   168  		}
   169  		addr = net.IPAddress(ips[dice.Roll(len(ips))])
   170  	}
   171  
   172  	var newCtx context.Context
   173  	var newCancel context.CancelFunc
   174  	if session.TimeoutOnlyFromContext(ctx) {
   175  		newCtx, newCancel = context.WithCancel(context.Background())
   176  	}
   177  
   178  	p := h.policyManager.ForLevel(0)
   179  
   180  	ctx, cancel := context.WithCancel(ctx)
   181  	timer := signal.CancelAfterInactivity(ctx, func() {
   182  		cancel()
   183  		if newCancel != nil {
   184  			newCancel()
   185  		}
   186  	}, p.Timeouts.ConnectionIdle)
   187  	addrPort := netip.AddrPortFrom(toNetIpAddr(addr), destination.Port.Value())
   188  
   189  	var requestFunc func() error
   190  	var responseFunc func() error
   191  
   192  	if command == protocol.RequestCommandTCP {
   193  		conn, err := h.net.DialContextTCPAddrPort(ctx, addrPort)
   194  		if err != nil {
   195  			return newError("failed to create TCP connection").Base(err)
   196  		}
   197  		defer conn.Close()
   198  
   199  		requestFunc = func() error {
   200  			defer timer.SetTimeout(p.Timeouts.DownlinkOnly)
   201  			return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
   202  		}
   203  		responseFunc = func() error {
   204  			defer timer.SetTimeout(p.Timeouts.UplinkOnly)
   205  			return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
   206  		}
   207  	} else if command == protocol.RequestCommandUDP {
   208  		conn, err := h.net.DialUDPAddrPort(netip.AddrPort{}, addrPort)
   209  		if err != nil {
   210  			return newError("failed to create UDP connection").Base(err)
   211  		}
   212  		defer conn.Close()
   213  
   214  		requestFunc = func() error {
   215  			defer timer.SetTimeout(p.Timeouts.DownlinkOnly)
   216  			return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
   217  		}
   218  		responseFunc = func() error {
   219  			defer timer.SetTimeout(p.Timeouts.UplinkOnly)
   220  			return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
   221  		}
   222  	}
   223  
   224  	if newCtx != nil {
   225  		ctx = newCtx
   226  	}
   227  
   228  	responseDonePost := task.OnSuccess(responseFunc, task.Close(link.Writer))
   229  	if err := task.Run(ctx, requestFunc, responseDonePost); err != nil {
   230  		common.Interrupt(link.Reader)
   231  		common.Interrupt(link.Writer)
   232  		return newError("connection ends").Base(err)
   233  	}
   234  
   235  	return nil
   236  }
   237  
   238  // creates a tun interface on netstack given a configuration
   239  func (h *Handler) makeVirtualTun(bind *netBindClient) (Tunnel, error) {
   240  	t, err := h.conf.createTun()(h.endpoints, int(h.conf.Mtu), nil)
   241  	if err != nil {
   242  		return nil, err
   243  	}
   244  
   245  	bind.dnsOption.IPv4Enable = h.hasIPv4
   246  	bind.dnsOption.IPv6Enable = h.hasIPv6
   247  
   248  	if err = t.BuildDevice(h.createIPCRequest(bind, h.conf), bind); err != nil {
   249  		_ = t.Close()
   250  		return nil, err
   251  	}
   252  	return t, nil
   253  }
   254  
   255  // serialize the config into an IPC request
   256  func (h *Handler) createIPCRequest(bind *netBindClient, conf *DeviceConfig) string {
   257  	var request strings.Builder
   258  
   259  	request.WriteString(fmt.Sprintf("private_key=%s\n", conf.SecretKey))
   260  
   261  	if !conf.IsClient {
   262  		// placeholder, we'll handle actual port listening on Xray
   263  		request.WriteString("listen_port=1337\n")
   264  	}
   265  
   266  	for _, peer := range conf.Peers {
   267  		if peer.PublicKey != "" {
   268  			request.WriteString(fmt.Sprintf("public_key=%s\n", peer.PublicKey))
   269  		}
   270  
   271  		if peer.PreSharedKey != "" {
   272  			request.WriteString(fmt.Sprintf("preshared_key=%s\n", peer.PreSharedKey))
   273  		}
   274  
   275  		address, port, err := net.SplitHostPort(peer.Endpoint)
   276  		if err != nil {
   277  			newError("failed to split endpoint ", peer.Endpoint, " into address and port").AtError().WriteToLog()
   278  		}
   279  		addr := net.ParseAddress(address)
   280  		if addr.Family().IsDomain() {
   281  			dialerIp := bind.dialer.DestIpAddress()
   282  			if dialerIp != nil {
   283  				addr = net.ParseAddress(dialerIp.String())
   284  				newError("createIPCRequest use dialer dest ip: ", addr).WriteToLog()
   285  			} else {
   286  				ips, err := h.dns.LookupIP(addr.Domain(), dns.IPOption{
   287  					IPv4Enable: h.hasIPv4 && h.conf.preferIP4(),
   288  					IPv6Enable: h.hasIPv6 && h.conf.preferIP6(),
   289  				})
   290  				{ // Resolve fallback
   291  					if (len(ips) == 0 || err != nil) && h.conf.hasFallback() {
   292  						ips, err = h.dns.LookupIP(addr.Domain(), dns.IPOption{
   293  							IPv4Enable: h.hasIPv4 && h.conf.fallbackIP4(),
   294  							IPv6Enable: h.hasIPv6 && h.conf.fallbackIP6(),
   295  						})
   296  					}
   297  				}
   298  				if err != nil {
   299  					newError("createIPCRequest failed to lookup DNS").Base(err).WriteToLog()
   300  				} else if len(ips) == 0 {
   301  					newError("createIPCRequest empty lookup DNS").WriteToLog()
   302  				} else {
   303  					addr = net.IPAddress(ips[dice.Roll(len(ips))])
   304  				}
   305  			}
   306  		}
   307  
   308  		if peer.Endpoint != "" {
   309  			request.WriteString(fmt.Sprintf("endpoint=%s:%s\n", addr, port))
   310  		}
   311  
   312  		for _, ip := range peer.AllowedIps {
   313  			request.WriteString(fmt.Sprintf("allowed_ip=%s\n", ip))
   314  		}
   315  
   316  		if peer.KeepAlive != 0 {
   317  			request.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", peer.KeepAlive))
   318  		}
   319  	}
   320  
   321  	return request.String()[:request.Len()]
   322  }