github.com/xmplusdev/xray-core@v1.8.10/proxy/wireguard/client.go (about)

     1  /*
     2  
     3  Some of the code is copied from https://github.com/octeep/wireproxy, license below.
     4  
     5  Copyright (c) 2022 Wind T.F. Wong <octeep@pm.me>
     6  
     7  Permission to use, copy, modify, and distribute this software for any
     8  purpose with or without fee is hereby granted, provided that the above
     9  copyright notice and this permission notice appear in all copies.
    10  
    11  THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
    12  WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
    13  MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
    14  ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
    15  WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
    16  ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
    17  OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
    18  
    19  */
    20  
    21  package wireguard
    22  
    23  import (
    24  	"context"
    25  	"fmt"
    26  	"net/netip"
    27  	"strings"
    28  	"sync"
    29  
    30  	"github.com/xmplusdev/xray-core/common"
    31  	"github.com/xmplusdev/xray-core/common/buf"
    32  	"github.com/xmplusdev/xray-core/common/dice"
    33  	"github.com/xmplusdev/xray-core/common/log"
    34  	"github.com/xmplusdev/xray-core/common/net"
    35  	"github.com/xmplusdev/xray-core/common/protocol"
    36  	"github.com/xmplusdev/xray-core/common/session"
    37  	"github.com/xmplusdev/xray-core/common/signal"
    38  	"github.com/xmplusdev/xray-core/common/task"
    39  	"github.com/xmplusdev/xray-core/core"
    40  	"github.com/xmplusdev/xray-core/features/dns"
    41  	"github.com/xmplusdev/xray-core/features/policy"
    42  	"github.com/xmplusdev/xray-core/transport"
    43  	"github.com/xmplusdev/xray-core/transport/internet"
    44  )
    45  
    46  // Handler is an outbound connection that silently swallow the entire payload.
    47  type Handler struct {
    48  	conf          *DeviceConfig
    49  	net           Tunnel
    50  	bind          *netBindClient
    51  	policyManager policy.Manager
    52  	dns           dns.Client
    53  	// cached configuration
    54  	endpoints        []netip.Addr
    55  	hasIPv4, hasIPv6 bool
    56  	wgLock           sync.Mutex
    57  }
    58  
    59  // New creates a new wireguard handler.
    60  func New(ctx context.Context, conf *DeviceConfig) (*Handler, error) {
    61  	v := core.MustFromContext(ctx)
    62  
    63  	endpoints, hasIPv4, hasIPv6, err := parseEndpoints(conf)
    64  	if err != nil {
    65  		return nil, err
    66  	}
    67  
    68  	d := v.GetFeature(dns.ClientType()).(dns.Client)
    69  	return &Handler{
    70  		conf:          conf,
    71  		policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
    72  		dns:           d,
    73  		endpoints:     endpoints,
    74  		hasIPv4:       hasIPv4,
    75  		hasIPv6:       hasIPv6,
    76  	}, nil
    77  }
    78  
    79  func (h *Handler) processWireGuard(dialer internet.Dialer) (err error) {
    80  	h.wgLock.Lock()
    81  	defer h.wgLock.Unlock()
    82  
    83  	if h.bind != nil && h.bind.dialer == dialer && h.net != nil {
    84  		return nil
    85  	}
    86  
    87  	log.Record(&log.GeneralMessage{
    88  		Severity: log.Severity_Info,
    89  		Content:  "switching dialer",
    90  	})
    91  
    92  	if h.net != nil {
    93  		_ = h.net.Close()
    94  		h.net = nil
    95  	}
    96  	if h.bind != nil {
    97  		_ = h.bind.Close()
    98  		h.bind = nil
    99  	}
   100  
   101  	// bind := conn.NewStdNetBind() // TODO: conn.Bind wrapper for dialer
   102  	bind := &netBindClient{
   103  		netBind: netBind{
   104  			dns: h.dns,
   105  			dnsOption: dns.IPOption{
   106  				IPv4Enable: h.hasIPv4,
   107  				IPv6Enable: h.hasIPv6,
   108  			},
   109  			workers: int(h.conf.NumWorkers),
   110  		},
   111  		dialer:   dialer,
   112  		reserved: h.conf.Reserved,
   113  	}
   114  	defer func() {
   115  		if err != nil {
   116  			_ = bind.Close()
   117  		}
   118  	}()
   119  
   120  	h.net, err = h.makeVirtualTun(bind)
   121  	if err != nil {
   122  		return newError("failed to create virtual tun interface").Base(err)
   123  	}
   124  	h.bind = bind
   125  	return nil
   126  }
   127  
   128  // Process implements OutboundHandler.Dispatch().
   129  func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
   130  	outbound := session.OutboundFromContext(ctx)
   131  	if outbound == nil || !outbound.Target.IsValid() {
   132  		return newError("target not specified")
   133  	}
   134  	outbound.Name = "wireguard"
   135  	inbound := session.InboundFromContext(ctx)
   136  	if inbound != nil {
   137  		inbound.SetCanSpliceCopy(3)
   138  	}
   139  
   140  	if err := h.processWireGuard(dialer); err != nil {
   141  		return err
   142  	}
   143  
   144  	// Destination of the inner request.
   145  	destination := outbound.Target
   146  	command := protocol.RequestCommandTCP
   147  	if destination.Network == net.Network_UDP {
   148  		command = protocol.RequestCommandUDP
   149  	}
   150  
   151  	// resolve dns
   152  	addr := destination.Address
   153  	if addr.Family().IsDomain() {
   154  		ips, err := h.dns.LookupIP(addr.Domain(), dns.IPOption{
   155  			IPv4Enable: h.hasIPv4 && h.conf.preferIP4(),
   156  			IPv6Enable: h.hasIPv6 && h.conf.preferIP6(),
   157  		})
   158  		{ // Resolve fallback
   159  			if (len(ips) == 0 || err != nil) && h.conf.hasFallback() {
   160  				ips, err = h.dns.LookupIP(addr.Domain(), dns.IPOption{
   161  					IPv4Enable: h.hasIPv4 && h.conf.fallbackIP4(),
   162  					IPv6Enable: h.hasIPv6 && h.conf.fallbackIP6(),
   163  				})
   164  			}
   165  		}
   166  		if err != nil {
   167  			return newError("failed to lookup DNS").Base(err)
   168  		} else if len(ips) == 0 {
   169  			return dns.ErrEmptyResponse
   170  		}
   171  		addr = net.IPAddress(ips[dice.Roll(len(ips))])
   172  	}
   173  
   174  	var newCtx context.Context
   175  	var newCancel context.CancelFunc
   176  	if session.TimeoutOnlyFromContext(ctx) {
   177  		newCtx, newCancel = context.WithCancel(context.Background())
   178  	}
   179  
   180  	p := h.policyManager.ForLevel(0)
   181  
   182  	ctx, cancel := context.WithCancel(ctx)
   183  	timer := signal.CancelAfterInactivity(ctx, func() {
   184  		cancel()
   185  		if newCancel != nil {
   186  			newCancel()
   187  		}
   188  	}, p.Timeouts.ConnectionIdle)
   189  	addrPort := netip.AddrPortFrom(toNetIpAddr(addr), destination.Port.Value())
   190  
   191  	var requestFunc func() error
   192  	var responseFunc func() error
   193  
   194  	if command == protocol.RequestCommandTCP {
   195  		conn, err := h.net.DialContextTCPAddrPort(ctx, addrPort)
   196  		if err != nil {
   197  			return newError("failed to create TCP connection").Base(err)
   198  		}
   199  		defer conn.Close()
   200  
   201  		requestFunc = func() error {
   202  			defer timer.SetTimeout(p.Timeouts.DownlinkOnly)
   203  			return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
   204  		}
   205  		responseFunc = func() error {
   206  			defer timer.SetTimeout(p.Timeouts.UplinkOnly)
   207  			return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
   208  		}
   209  	} else if command == protocol.RequestCommandUDP {
   210  		conn, err := h.net.DialUDPAddrPort(netip.AddrPort{}, addrPort)
   211  		if err != nil {
   212  			return newError("failed to create UDP connection").Base(err)
   213  		}
   214  		defer conn.Close()
   215  
   216  		requestFunc = func() error {
   217  			defer timer.SetTimeout(p.Timeouts.DownlinkOnly)
   218  			return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
   219  		}
   220  		responseFunc = func() error {
   221  			defer timer.SetTimeout(p.Timeouts.UplinkOnly)
   222  			return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
   223  		}
   224  	}
   225  
   226  	if newCtx != nil {
   227  		ctx = newCtx
   228  	}
   229  
   230  	responseDonePost := task.OnSuccess(responseFunc, task.Close(link.Writer))
   231  	if err := task.Run(ctx, requestFunc, responseDonePost); err != nil {
   232  		common.Interrupt(link.Reader)
   233  		common.Interrupt(link.Writer)
   234  		return newError("connection ends").Base(err)
   235  	}
   236  
   237  	return nil
   238  }
   239  
   240  // creates a tun interface on netstack given a configuration
   241  func (h *Handler) makeVirtualTun(bind *netBindClient) (Tunnel, error) {
   242  	t, err := h.conf.createTun()(h.endpoints, int(h.conf.Mtu), nil)
   243  	if err != nil {
   244  		return nil, err
   245  	}
   246  
   247  	bind.dnsOption.IPv4Enable = h.hasIPv4
   248  	bind.dnsOption.IPv6Enable = h.hasIPv6
   249  
   250  	if err = t.BuildDevice(h.createIPCRequest(bind, h.conf), bind); err != nil {
   251  		_ = t.Close()
   252  		return nil, err
   253  	}
   254  	return t, nil
   255  }
   256  
   257  // serialize the config into an IPC request
   258  func (h *Handler) createIPCRequest(bind *netBindClient, conf *DeviceConfig) string {
   259  	var request strings.Builder
   260  
   261  	request.WriteString(fmt.Sprintf("private_key=%s\n", conf.SecretKey))
   262  
   263  	if !conf.IsClient {
   264  		// placeholder, we'll handle actual port listening on Xray
   265  		request.WriteString("listen_port=1337\n")
   266  	}
   267  
   268  	for _, peer := range conf.Peers {
   269  		if peer.PublicKey != "" {
   270  			request.WriteString(fmt.Sprintf("public_key=%s\n", peer.PublicKey))
   271  		}
   272  
   273  		if peer.PreSharedKey != "" {
   274  			request.WriteString(fmt.Sprintf("preshared_key=%s\n", peer.PreSharedKey))
   275  		}
   276  
   277  		address, port, err := net.SplitHostPort(peer.Endpoint)
   278  		if err != nil {
   279  			newError("failed to split endpoint ", peer.Endpoint, " into address and port").AtError().WriteToLog()
   280  		}
   281  		addr := net.ParseAddress(address)
   282  		if addr.Family().IsDomain() {
   283  			dialerIp := bind.dialer.DestIpAddress()
   284  			if dialerIp != nil {
   285  				addr = net.ParseAddress(dialerIp.String())
   286  				newError("createIPCRequest use dialer dest ip: ", addr).WriteToLog()
   287  			} else {
   288  				ips, err := h.dns.LookupIP(addr.Domain(), dns.IPOption{
   289  					IPv4Enable: h.hasIPv4 && h.conf.preferIP4(),
   290  					IPv6Enable: h.hasIPv6 && h.conf.preferIP6(),
   291  				})
   292  				{ // Resolve fallback
   293  					if (len(ips) == 0 || err != nil) && h.conf.hasFallback() {
   294  						ips, err = h.dns.LookupIP(addr.Domain(), dns.IPOption{
   295  							IPv4Enable: h.hasIPv4 && h.conf.fallbackIP4(),
   296  							IPv6Enable: h.hasIPv6 && h.conf.fallbackIP6(),
   297  						})
   298  					}
   299  				}
   300  				if err != nil {
   301  					newError("createIPCRequest failed to lookup DNS").Base(err).WriteToLog()
   302  				} else if len(ips) == 0 {
   303  					newError("createIPCRequest empty lookup DNS").WriteToLog()
   304  				} else {
   305  					addr = net.IPAddress(ips[dice.Roll(len(ips))])
   306  				}
   307  			}
   308  		}
   309  
   310  		if peer.Endpoint != "" {
   311  			request.WriteString(fmt.Sprintf("endpoint=%s:%s\n", addr, port))
   312  		}
   313  
   314  		for _, ip := range peer.AllowedIps {
   315  			request.WriteString(fmt.Sprintf("allowed_ip=%s\n", ip))
   316  		}
   317  
   318  		if peer.KeepAlive != 0 {
   319  			request.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", peer.KeepAlive))
   320  		}
   321  	}
   322  
   323  	return request.String()[:request.Len()]
   324  }