github.com/moqsien/xraycore@v1.8.5/proxy/wireguard/wireguard.go (about)

     1  /*
     2  
     3  Some of codes are copied from https://github.com/octeep/wireproxy, license below.
     4  
     5  Copyright (c) 2022 Wind T.F. Wong <octeep@pm.me>
     6  
     7  Permission to use, copy, modify, and distribute this software for any
     8  purpose with or without fee is hereby granted, provided that the above
     9  copyright notice and this permission notice appear in all copies.
    10  
    11  THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
    12  WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
    13  MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
    14  ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
    15  WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
    16  ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
    17  OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
    18  
    19  */
    20  
    21  package wireguard
    22  
    23  import (
    24  	"bytes"
    25  	"context"
    26  	"fmt"
    27  	"net/netip"
    28  	"strings"
    29  
    30  	"github.com/sagernet/wireguard-go/device"
    31  	"github.com/moqsien/xraycore/common"
    32  	"github.com/moqsien/xraycore/common/buf"
    33  	"github.com/moqsien/xraycore/common/log"
    34  	"github.com/moqsien/xraycore/common/net"
    35  	"github.com/moqsien/xraycore/common/protocol"
    36  	"github.com/moqsien/xraycore/common/session"
    37  	"github.com/moqsien/xraycore/common/signal"
    38  	"github.com/moqsien/xraycore/common/task"
    39  	"github.com/moqsien/xraycore/core"
    40  	"github.com/moqsien/xraycore/features/dns"
    41  	"github.com/moqsien/xraycore/features/policy"
    42  	"github.com/moqsien/xraycore/transport"
    43  	"github.com/moqsien/xraycore/transport/internet"
    44  )
    45  
// Handler is the WireGuard outbound handler. It holds the device
// configuration it was created from, the policy manager and DNS client
// obtained from the core instance, and the virtual network (net) together with
// the bind that Process builds for the dialer currently in use. net and bind
// are left unset by New and are filled in by Process.
type Handler struct {
	conf          *DeviceConfig
	net           *Net
	bind          *netBindClient
	policyManager policy.Manager
	dns           dns.Client
	// cached configuration, computed once in New: ipc is the serialized IPC
	// request handed to the device, endpoints are the parsed interface addresses.
	ipc       string
	endpoints []netip.Addr
}
    57  
    58  // New creates a new wireguard handler.
    59  func New(ctx context.Context, conf *DeviceConfig) (*Handler, error) {
    60  	v := core.MustFromContext(ctx)
    61  
    62  	endpoints, err := parseEndpoints(conf)
    63  	if err != nil {
    64  		return nil, err
    65  	}
    66  
    67  	return &Handler{
    68  		conf:          conf,
    69  		policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
    70  		dns:           v.GetFeature(dns.ClientType()).(dns.Client),
    71  		ipc:           createIPCRequest(conf),
    72  		endpoints:     endpoints,
    73  	}, nil
    74  }
    75  
    76  // Process implements OutboundHandler.Dispatch().
    77  func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
    78  	if h.bind == nil || h.bind.dialer != dialer || h.net == nil {
    79  		log.Record(&log.GeneralMessage{
    80  			Severity: log.Severity_Info,
    81  			Content:  "switching dialer",
    82  		})
    83  		// bind := conn.NewStdNetBind() // TODO: conn.Bind wrapper for dialer
    84  		bind := &netBindClient{
    85  			dialer:   dialer,
    86  			workers:  int(h.conf.NumWorkers),
    87  			dns:      h.dns,
    88  			reserved: h.conf.Reserved,
    89  		}
    90  
    91  		net, err := h.makeVirtualTun(bind)
    92  		if err != nil {
    93  			bind.Close()
    94  			return newError("failed to create virtual tun interface").Base(err)
    95  		}
    96  
    97  		h.net = net
    98  		if h.bind != nil {
    99  			h.bind.Close()
   100  		}
   101  		h.bind = bind
   102  	}
   103  
   104  	outbound := session.OutboundFromContext(ctx)
   105  	if outbound == nil || !outbound.Target.IsValid() {
   106  		return newError("target not specified")
   107  	}
   108  	// Destination of the inner request.
   109  	destination := outbound.Target
   110  	command := protocol.RequestCommandTCP
   111  	if destination.Network == net.Network_UDP {
   112  		command = protocol.RequestCommandUDP
   113  	}
   114  
   115  	// resolve dns
   116  	addr := destination.Address
   117  	if addr.Family().IsDomain() {
   118  		ips, err := h.dns.LookupIP(addr.Domain(), dns.IPOption{
   119  			IPv4Enable: h.net.HasV4(),
   120  			IPv6Enable: h.net.HasV6(),
   121  		})
   122  		if err != nil {
   123  			return newError("failed to lookup DNS").Base(err)
   124  		} else if len(ips) == 0 {
   125  			return dns.ErrEmptyResponse
   126  		}
   127  		addr = net.IPAddress(ips[0])
   128  	}
   129  
   130  	var newCtx context.Context
   131  	var newCancel context.CancelFunc
   132  	if session.TimeoutOnlyFromContext(ctx) {
   133  		newCtx, newCancel = context.WithCancel(context.Background())
   134  	}
   135  
   136  	p := h.policyManager.ForLevel(0)
   137  
   138  	ctx, cancel := context.WithCancel(ctx)
   139  	timer := signal.CancelAfterInactivity(ctx, func() {
   140  		cancel()
   141  		if newCancel != nil {
   142  			newCancel()
   143  		}
   144  	}, p.Timeouts.ConnectionIdle)
   145  	addrPort := netip.AddrPortFrom(toNetIpAddr(addr), destination.Port.Value())
   146  
   147  	var requestFunc func() error
   148  	var responseFunc func() error
   149  
   150  	if command == protocol.RequestCommandTCP {
   151  		conn, err := h.net.DialContextTCPAddrPort(ctx, addrPort)
   152  		if err != nil {
   153  			return newError("failed to create TCP connection").Base(err)
   154  		}
   155  		defer conn.Close()
   156  
   157  		requestFunc = func() error {
   158  			defer timer.SetTimeout(p.Timeouts.DownlinkOnly)
   159  			return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
   160  		}
   161  		responseFunc = func() error {
   162  			defer timer.SetTimeout(p.Timeouts.UplinkOnly)
   163  			return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
   164  		}
   165  	} else if command == protocol.RequestCommandUDP {
   166  		conn, err := h.net.DialUDPAddrPort(netip.AddrPort{}, addrPort)
   167  		if err != nil {
   168  			return newError("failed to create UDP connection").Base(err)
   169  		}
   170  		defer conn.Close()
   171  
   172  		requestFunc = func() error {
   173  			defer timer.SetTimeout(p.Timeouts.DownlinkOnly)
   174  			return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
   175  		}
   176  		responseFunc = func() error {
   177  			defer timer.SetTimeout(p.Timeouts.UplinkOnly)
   178  			return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
   179  		}
   180  	}
   181  
   182  	if newCtx != nil {
   183  		ctx = newCtx
   184  	}
   185  
   186  	responseDonePost := task.OnSuccess(responseFunc, task.Close(link.Writer))
   187  	if err := task.Run(ctx, requestFunc, responseDonePost); err != nil {
   188  		common.Interrupt(link.Reader)
   189  		common.Interrupt(link.Writer)
   190  		return newError("connection ends").Base(err)
   191  	}
   192  
   193  	return nil
   194  }
   195  
   196  // serialize the config into an IPC request
   197  func createIPCRequest(conf *DeviceConfig) string {
   198  	var request bytes.Buffer
   199  
   200  	request.WriteString(fmt.Sprintf("private_key=%s\n", conf.SecretKey))
   201  
   202  	for _, peer := range conf.Peers {
   203  		request.WriteString(fmt.Sprintf("public_key=%s\nendpoint=%s\npersistent_keepalive_interval=%d\npreshared_key=%s\n",
   204  			peer.PublicKey, peer.Endpoint, peer.KeepAlive, peer.PreSharedKey))
   205  
   206  		for _, ip := range peer.AllowedIps {
   207  			request.WriteString(fmt.Sprintf("allowed_ip=%s\n", ip))
   208  		}
   209  	}
   210  
   211  	return request.String()[:request.Len()]
   212  }
   213  
   214  // convert endpoint string to netip.Addr
   215  func parseEndpoints(conf *DeviceConfig) ([]netip.Addr, error) {
   216  	endpoints := make([]netip.Addr, len(conf.Endpoint))
   217  	for i, str := range conf.Endpoint {
   218  		var addr netip.Addr
   219  		if strings.Contains(str, "/") {
   220  			prefix, err := netip.ParsePrefix(str)
   221  			if err != nil {
   222  				return nil, err
   223  			}
   224  			addr = prefix.Addr()
   225  			if prefix.Bits() != addr.BitLen() {
   226  				return nil, newError("interface address subnet should be /32 for IPv4 and /128 for IPv6")
   227  			}
   228  		} else {
   229  			var err error
   230  			addr, err = netip.ParseAddr(str)
   231  			if err != nil {
   232  				return nil, err
   233  			}
   234  		}
   235  		endpoints[i] = addr
   236  	}
   237  
   238  	return endpoints, nil
   239  }
   240  
   241  // creates a tun interface on netstack given a configuration
   242  func (h *Handler) makeVirtualTun(bind *netBindClient) (*Net, error) {
   243  	tun, tnet, err := CreateNetTUN(h.endpoints, h.dns, int(h.conf.Mtu))
   244  	if err != nil {
   245  		return nil, err
   246  	}
   247  
   248  	bind.dnsOption.IPv4Enable = tnet.HasV4()
   249  	bind.dnsOption.IPv6Enable = tnet.HasV6()
   250  
   251  	// dev := device.NewDevice(tun, conn.NewDefaultBind(), nil /* device.NewLogger(device.LogLevelVerbose, "") */)
   252  	dev := device.NewDevice(context.Background(), tun, bind, &device.Logger{
   253  		Verbosef: func(format string, args ...any) {
   254  			log.Record(&log.GeneralMessage{
   255  				Severity: log.Severity_Debug,
   256  				Content:  fmt.Sprintf(format, args...),
   257  			})
   258  		},
   259  		Errorf: func(format string, args ...any) {
   260  			log.Record(&log.GeneralMessage{
   261  				Severity: log.Severity_Error,
   262  				Content:  fmt.Sprintf(format, args...),
   263  			})
   264  		},
   265  	}, int(h.conf.NumWorkers))
   266  	err = dev.IpcSet(h.ipc)
   267  	if err != nil {
   268  		return nil, err
   269  	}
   270  
   271  	err = dev.Up()
   272  	if err != nil {
   273  		return nil, err
   274  	}
   275  
   276  	return tnet, nil
   277  }
   278  
   279  func init() {
   280  	common.Must(common.RegisterConfig((*DeviceConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
   281  		return New(ctx, config.(*DeviceConfig))
   282  	}))
   283  }