github.com/xraypb/Xray-core@v1.8.1/proxy/wireguard/wireguard.go (about)

     1  /*
     2  
Some of the code is copied from https://github.com/octeep/wireproxy; its license follows.
     4  
     5  Copyright (c) 2022 Wind T.F. Wong <octeep@pm.me>
     6  
     7  Permission to use, copy, modify, and distribute this software for any
     8  purpose with or without fee is hereby granted, provided that the above
     9  copyright notice and this permission notice appear in all copies.
    10  
    11  THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
    12  WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
    13  MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
    14  ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
    15  WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
    16  ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
    17  OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
    18  
    19  */
    20  
    21  package wireguard
    22  
import (
	"bytes"
	"context"
	"fmt"
	"io"
	"net/netip"
	"strings"
	"sync"

	"github.com/sagernet/wireguard-go/device"
	"github.com/xraypb/Xray-core/common"
	"github.com/xraypb/Xray-core/common/buf"
	"github.com/xraypb/Xray-core/common/log"
	"github.com/xraypb/Xray-core/common/net"
	"github.com/xraypb/Xray-core/common/protocol"
	"github.com/xraypb/Xray-core/common/session"
	"github.com/xraypb/Xray-core/common/signal"
	"github.com/xraypb/Xray-core/common/task"
	"github.com/xraypb/Xray-core/core"
	"github.com/xraypb/Xray-core/features/dns"
	"github.com/xraypb/Xray-core/features/policy"
	"github.com/xraypb/Xray-core/transport"
	"github.com/xraypb/Xray-core/transport/internet"
)
    45  
    46  // Handler is an outbound connection that silently swallow the entire payload.
    47  type Handler struct {
    48  	conf          *DeviceConfig
    49  	net           *Net
    50  	bind          *netBindClient
    51  	policyManager policy.Manager
    52  	dns           dns.Client
    53  	// cached configuration
    54  	ipc       string
    55  	endpoints []netip.Addr
    56  }
    57  
    58  // New creates a new wireguard handler.
    59  func New(ctx context.Context, conf *DeviceConfig) (*Handler, error) {
    60  	v := core.MustFromContext(ctx)
    61  
    62  	endpoints, err := parseEndpoints(conf)
    63  	if err != nil {
    64  		return nil, err
    65  	}
    66  
    67  	return &Handler{
    68  		conf:          conf,
    69  		policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
    70  		dns:           v.GetFeature(dns.ClientType()).(dns.Client),
    71  		ipc:           createIPCRequest(conf),
    72  		endpoints:     endpoints,
    73  	}, nil
    74  }
    75  
    76  // Process implements OutboundHandler.Dispatch().
    77  func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
    78  	if h.bind == nil || h.bind.dialer != dialer || h.net == nil {
    79  		log.Record(&log.GeneralMessage{
    80  			Severity: log.Severity_Info,
    81  			Content:  "switching dialer",
    82  		})
    83  		// bind := conn.NewStdNetBind() // TODO: conn.Bind wrapper for dialer
    84  		bind := &netBindClient{
    85  			dialer:   dialer,
    86  			workers:  int(h.conf.NumWorkers),
    87  			dns:      h.dns,
    88  			reserved: h.conf.Reserved,
    89  		}
    90  
    91  		net, err := h.makeVirtualTun(bind)
    92  		if err != nil {
    93  			bind.Close()
    94  			return newError("failed to create virtual tun interface").Base(err)
    95  		}
    96  
    97  		h.net = net
    98  		if h.bind != nil {
    99  			h.bind.Close()
   100  		}
   101  		h.bind = bind
   102  	}
   103  
   104  	outbound := session.OutboundFromContext(ctx)
   105  	if outbound == nil || !outbound.Target.IsValid() {
   106  		return newError("target not specified")
   107  	}
   108  	// Destination of the inner request.
   109  	destination := outbound.Target
   110  	command := protocol.RequestCommandTCP
   111  	if destination.Network == net.Network_UDP {
   112  		command = protocol.RequestCommandUDP
   113  	}
   114  
   115  	// resolve dns
   116  	addr := destination.Address
   117  	if addr.Family().IsDomain() {
   118  		ips, err := h.dns.LookupIP(addr.Domain(), dns.IPOption{
   119  			IPv4Enable: h.net.HasV4(),
   120  			IPv6Enable: h.net.HasV6(),
   121  		})
   122  		if err != nil {
   123  			return newError("failed to lookup DNS").Base(err)
   124  		} else if len(ips) == 0 {
   125  			return dns.ErrEmptyResponse
   126  		}
   127  		addr = net.IPAddress(ips[0])
   128  	}
   129  
   130  	var newCtx context.Context
   131  	var newCancel context.CancelFunc
   132  	if session.TimeoutOnlyFromContext(ctx) {
   133  		newCtx, newCancel = context.WithCancel(context.Background())
   134  	}
   135  
   136  	p := h.policyManager.ForLevel(0)
   137  
   138  	ctx, cancel := context.WithCancel(ctx)
   139  	timer := signal.CancelAfterInactivity(ctx, func() {
   140  		cancel()
   141  		if newCancel != nil {
   142  			newCancel()
   143  		}
   144  	}, p.Timeouts.ConnectionIdle)
   145  	addrPort := netip.AddrPortFrom(toNetIpAddr(addr), destination.Port.Value())
   146  
   147  	var requestFunc func() error
   148  	var responseFunc func() error
   149  
   150  	if command == protocol.RequestCommandTCP {
   151  		conn, err := h.net.DialContextTCPAddrPort(ctx, addrPort)
   152  		if err != nil {
   153  			return newError("failed to create TCP connection").Base(err)
   154  		}
   155  
   156  		requestFunc = func() error {
   157  			defer timer.SetTimeout(p.Timeouts.DownlinkOnly)
   158  			return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
   159  		}
   160  		responseFunc = func() error {
   161  			defer timer.SetTimeout(p.Timeouts.UplinkOnly)
   162  			return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
   163  		}
   164  	} else if command == protocol.RequestCommandUDP {
   165  		conn, err := h.net.DialUDPAddrPort(netip.AddrPort{}, addrPort)
   166  		if err != nil {
   167  			return newError("failed to create UDP connection").Base(err)
   168  		}
   169  
   170  		requestFunc = func() error {
   171  			defer timer.SetTimeout(p.Timeouts.DownlinkOnly)
   172  			return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
   173  		}
   174  		responseFunc = func() error {
   175  			defer timer.SetTimeout(p.Timeouts.UplinkOnly)
   176  			return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
   177  		}
   178  	}
   179  
   180  	if newCtx != nil {
   181  		ctx = newCtx
   182  	}
   183  
   184  	responseDonePost := task.OnSuccess(responseFunc, task.Close(link.Writer))
   185  	if err := task.Run(ctx, requestFunc, responseDonePost); err != nil {
   186  		return newError("connection ends").Base(err)
   187  	}
   188  
   189  	return nil
   190  }
   191  
   192  // serialize the config into an IPC request
   193  func createIPCRequest(conf *DeviceConfig) string {
   194  	var request bytes.Buffer
   195  
   196  	request.WriteString(fmt.Sprintf("private_key=%s\n", conf.SecretKey))
   197  
   198  	for _, peer := range conf.Peers {
   199  		request.WriteString(fmt.Sprintf("public_key=%s\nendpoint=%s\npersistent_keepalive_interval=%d\npreshared_key=%s\n",
   200  			peer.PublicKey, peer.Endpoint, peer.KeepAlive, peer.PreSharedKey))
   201  
   202  		for _, ip := range peer.AllowedIps {
   203  			request.WriteString(fmt.Sprintf("allowed_ip=%s\n", ip))
   204  		}
   205  	}
   206  
   207  	return request.String()[:request.Len()]
   208  }
   209  
   210  // convert endpoint string to netip.Addr
   211  func parseEndpoints(conf *DeviceConfig) ([]netip.Addr, error) {
   212  	endpoints := make([]netip.Addr, len(conf.Endpoint))
   213  	for i, str := range conf.Endpoint {
   214  		var addr netip.Addr
   215  		if strings.Contains(str, "/") {
   216  			prefix, err := netip.ParsePrefix(str)
   217  			if err != nil {
   218  				return nil, err
   219  			}
   220  			addr = prefix.Addr()
   221  			if prefix.Bits() != addr.BitLen() {
   222  				return nil, newError("interface address subnet should be /32 for IPv4 and /128 for IPv6")
   223  			}
   224  		} else {
   225  			var err error
   226  			addr, err = netip.ParseAddr(str)
   227  			if err != nil {
   228  				return nil, err
   229  			}
   230  		}
   231  		endpoints[i] = addr
   232  	}
   233  
   234  	return endpoints, nil
   235  }
   236  
   237  // creates a tun interface on netstack given a configuration
   238  func (h *Handler) makeVirtualTun(bind *netBindClient) (*Net, error) {
   239  	tun, tnet, err := CreateNetTUN(h.endpoints, h.dns, int(h.conf.Mtu))
   240  	if err != nil {
   241  		return nil, err
   242  	}
   243  
   244  	bind.dnsOption.IPv4Enable = tnet.HasV4()
   245  	bind.dnsOption.IPv6Enable = tnet.HasV6()
   246  
   247  	// dev := device.NewDevice(tun, conn.NewDefaultBind(), nil /* device.NewLogger(device.LogLevelVerbose, "") */)
   248  	dev := device.NewDevice(tun, bind, &device.Logger{
   249  		Verbosef: func(format string, args ...any) {
   250  			log.Record(&log.GeneralMessage{
   251  				Severity: log.Severity_Debug,
   252  				Content:  fmt.Sprintf(format, args...),
   253  			})
   254  		},
   255  		Errorf: func(format string, args ...any) {
   256  			log.Record(&log.GeneralMessage{
   257  				Severity: log.Severity_Error,
   258  				Content:  fmt.Sprintf(format, args...),
   259  			})
   260  		},
   261  	}, int(h.conf.NumWorkers))
   262  	err = dev.IpcSet(h.ipc)
   263  	if err != nil {
   264  		return nil, err
   265  	}
   266  
   267  	err = dev.Up()
   268  	if err != nil {
   269  		return nil, err
   270  	}
   271  
   272  	return tnet, nil
   273  }
   274  
   275  func init() {
   276  	common.Must(common.RegisterConfig((*DeviceConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
   277  		return New(ctx, config.(*DeviceConfig))
   278  	}))
   279  }