github.com/xraypb/xray-core@v1.6.6/proxy/wireguard/wireguard.go (about)

     1  /*
     2  
Some of the code is copied from https://github.com/octeep/wireproxy; its license is reproduced below.
     4  
     5  Copyright (c) 2022 Wind T.F. Wong <octeep@pm.me>
     6  
     7  Permission to use, copy, modify, and distribute this software for any
     8  purpose with or without fee is hereby granted, provided that the above
     9  copyright notice and this permission notice appear in all copies.
    10  
    11  THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
    12  WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
    13  MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
    14  ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
    15  WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
    16  ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
    17  OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
    18  
    19  */
    20  
    21  package wireguard
    22  
import (
	"bytes"
	"context"
	"fmt"
	"net/netip"
	"strings"
	"sync"

	"github.com/sagernet/wireguard-go/device"
	"github.com/xraypb/xray-core/common"
	"github.com/xraypb/xray-core/common/buf"
	"github.com/xraypb/xray-core/common/log"
	"github.com/xraypb/xray-core/common/net"
	"github.com/xraypb/xray-core/common/protocol"
	"github.com/xraypb/xray-core/common/session"
	"github.com/xraypb/xray-core/common/signal"
	"github.com/xraypb/xray-core/common/task"
	"github.com/xraypb/xray-core/core"
	"github.com/xraypb/xray-core/features/dns"
	"github.com/xraypb/xray-core/features/policy"
	"github.com/xraypb/xray-core/transport"
	"github.com/xraypb/xray-core/transport/internet"
)
    45  
    46  // Handler is an outbound connection that silently swallow the entire payload.
    47  type Handler struct {
    48  	conf          *DeviceConfig
    49  	net           *Net
    50  	bind          *netBindClient
    51  	policyManager policy.Manager
    52  	dns           dns.Client
    53  	// cached configuration
    54  	ipc       string
    55  	endpoints []netip.Addr
    56  }
    57  
    58  // New creates a new wireguard handler.
    59  func New(ctx context.Context, conf *DeviceConfig) (*Handler, error) {
    60  	v := core.MustFromContext(ctx)
    61  
    62  	endpoints, err := parseEndpoints(conf)
    63  	if err != nil {
    64  		return nil, err
    65  	}
    66  
    67  	return &Handler{
    68  		conf:          conf,
    69  		policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
    70  		dns:           v.GetFeature(dns.ClientType()).(dns.Client),
    71  		ipc:           createIPCRequest(conf),
    72  		endpoints:     endpoints,
    73  	}, nil
    74  }
    75  
    76  // Process implements OutboundHandler.Dispatch().
    77  func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
    78  	if h.bind == nil || h.bind.dialer != dialer || h.net == nil {
    79  		log.Record(&log.GeneralMessage{
    80  			Severity: log.Severity_Info,
    81  			Content:  "switching dialer",
    82  		})
    83  		// bind := conn.NewStdNetBind() // TODO: conn.Bind wrapper for dialer
    84  		bind := &netBindClient{
    85  			dialer:  dialer,
    86  			workers: int(h.conf.NumWorkers),
    87  			dns:     h.dns,
    88  		}
    89  
    90  		net, err := h.makeVirtualTun(bind)
    91  		if err != nil {
    92  			bind.Close()
    93  			return newError("failed to create virtual tun interface").Base(err)
    94  		}
    95  
    96  		h.net = net
    97  		if h.bind != nil {
    98  			h.bind.Close()
    99  		}
   100  		h.bind = bind
   101  	}
   102  
   103  	outbound := session.OutboundFromContext(ctx)
   104  	if outbound == nil || !outbound.Target.IsValid() {
   105  		return newError("target not specified")
   106  	}
   107  	// Destination of the inner request.
   108  	destination := outbound.Target
   109  	command := protocol.RequestCommandTCP
   110  	if destination.Network == net.Network_UDP {
   111  		command = protocol.RequestCommandUDP
   112  	}
   113  
   114  	// resolve dns
   115  	addr := destination.Address
   116  	if addr.Family().IsDomain() {
   117  		ips, err := h.dns.LookupIP(addr.Domain(), dns.IPOption{
   118  			IPv4Enable: h.net.HasV4(),
   119  			IPv6Enable: h.net.HasV6(),
   120  		})
   121  		if err != nil {
   122  			return newError("failed to lookup DNS").Base(err)
   123  		} else if len(ips) == 0 {
   124  			return dns.ErrEmptyResponse
   125  		}
   126  		addr = net.IPAddress(ips[0])
   127  	}
   128  
   129  	p := h.policyManager.ForLevel(0)
   130  
   131  	ctx, cancel := context.WithCancel(ctx)
   132  	timer := signal.CancelAfterInactivity(ctx, cancel, p.Timeouts.ConnectionIdle)
   133  	addrPort := netip.AddrPortFrom(toNetIpAddr(addr), destination.Port.Value())
   134  
   135  	var requestFunc func() error
   136  	var responseFunc func() error
   137  
   138  	if command == protocol.RequestCommandTCP {
   139  		conn, err := h.net.DialContextTCPAddrPort(ctx, addrPort)
   140  		if err != nil {
   141  			return newError("failed to create TCP connection").Base(err)
   142  		}
   143  
   144  		requestFunc = func() error {
   145  			defer timer.SetTimeout(p.Timeouts.DownlinkOnly)
   146  			return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
   147  		}
   148  		responseFunc = func() error {
   149  			defer timer.SetTimeout(p.Timeouts.UplinkOnly)
   150  			return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
   151  		}
   152  	} else if command == protocol.RequestCommandUDP {
   153  		conn, err := h.net.DialUDPAddrPort(netip.AddrPort{}, addrPort)
   154  		if err != nil {
   155  			return newError("failed to create UDP connection").Base(err)
   156  		}
   157  
   158  		requestFunc = func() error {
   159  			defer timer.SetTimeout(p.Timeouts.DownlinkOnly)
   160  			return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
   161  		}
   162  		responseFunc = func() error {
   163  			defer timer.SetTimeout(p.Timeouts.UplinkOnly)
   164  			return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
   165  		}
   166  	}
   167  
   168  	responseDonePost := task.OnSuccess(responseFunc, task.Close(link.Writer))
   169  	if err := task.Run(ctx, requestFunc, responseDonePost); err != nil {
   170  		return newError("connection ends").Base(err)
   171  	}
   172  
   173  	return nil
   174  }
   175  
   176  // serialize the config into an IPC request
   177  func createIPCRequest(conf *DeviceConfig) string {
   178  	var request bytes.Buffer
   179  
   180  	request.WriteString(fmt.Sprintf("private_key=%s\n", conf.SecretKey))
   181  
   182  	for _, peer := range conf.Peers {
   183  		request.WriteString(fmt.Sprintf("public_key=%s\nendpoint=%s\npersistent_keepalive_interval=%d\npreshared_key=%s\n",
   184  			peer.PublicKey, peer.Endpoint, peer.KeepAlive, peer.PreSharedKey))
   185  
   186  		for _, ip := range peer.AllowedIps {
   187  			request.WriteString(fmt.Sprintf("allowed_ip=%s\n", ip))
   188  		}
   189  	}
   190  
   191  	return request.String()[:request.Len()]
   192  }
   193  
   194  // convert endpoint string to netip.Addr
   195  func parseEndpoints(conf *DeviceConfig) ([]netip.Addr, error) {
   196  	endpoints := make([]netip.Addr, len(conf.Endpoint))
   197  	for i, str := range conf.Endpoint {
   198  		var addr netip.Addr
   199  		if strings.Contains(str, "/") {
   200  			prefix, err := netip.ParsePrefix(str)
   201  			if err != nil {
   202  				return nil, err
   203  			}
   204  			addr = prefix.Addr()
   205  			if prefix.Bits() != addr.BitLen() {
   206  				return nil, newError("interface address subnet should be /32 for IPv4 and /128 for IPv6")
   207  			}
   208  		} else {
   209  			var err error
   210  			addr, err = netip.ParseAddr(str)
   211  			if err != nil {
   212  				return nil, err
   213  			}
   214  		}
   215  		endpoints[i] = addr
   216  	}
   217  
   218  	return endpoints, nil
   219  }
   220  
   221  // creates a tun interface on netstack given a configuration
   222  func (h *Handler) makeVirtualTun(bind *netBindClient) (*Net, error) {
   223  	tun, tnet, err := CreateNetTUN(h.endpoints, h.dns, int(h.conf.Mtu))
   224  	if err != nil {
   225  		return nil, err
   226  	}
   227  
   228  	bind.dnsOption.IPv4Enable = tnet.HasV4()
   229  	bind.dnsOption.IPv6Enable = tnet.HasV6()
   230  
   231  	// dev := device.NewDevice(tun, conn.NewDefaultBind(), nil /* device.NewLogger(device.LogLevelVerbose, "") */)
   232  	dev := device.NewDevice(tun, bind, &device.Logger{
   233  		Verbosef: func(format string, args ...any) {
   234  			log.Record(&log.GeneralMessage{
   235  				Severity: log.Severity_Debug,
   236  				Content:  fmt.Sprintf(format, args...),
   237  			})
   238  		},
   239  		Errorf: func(format string, args ...any) {
   240  			log.Record(&log.GeneralMessage{
   241  				Severity: log.Severity_Error,
   242  				Content:  fmt.Sprintf(format, args...),
   243  			})
   244  		},
   245  	}, int(h.conf.NumWorkers))
   246  	err = dev.IpcSet(h.ipc)
   247  	if err != nil {
   248  		return nil, err
   249  	}
   250  
   251  	err = dev.Up()
   252  	if err != nil {
   253  		return nil, err
   254  	}
   255  
   256  	return tnet, nil
   257  }
   258  
   259  func init() {
   260  	common.Must(common.RegisterConfig((*DeviceConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
   261  		return New(ctx, config.(*DeviceConfig))
   262  	}))
   263  }