github.com/sagernet/sing-box@v1.2.7/outbound/wireguard.go (about)

     1  //go:build with_wireguard
     2  
     3  package outbound
     4  
     5  import (
     6  	"context"
     7  	"encoding/base64"
     8  	"encoding/hex"
     9  	"fmt"
    10  	"net"
    11  	"strings"
    12  
    13  	"github.com/sagernet/sing-box/adapter"
    14  	"github.com/sagernet/sing-box/common/dialer"
    15  	C "github.com/sagernet/sing-box/constant"
    16  	"github.com/sagernet/sing-box/log"
    17  	"github.com/sagernet/sing-box/option"
    18  	"github.com/sagernet/sing-box/transport/wireguard"
    19  	"github.com/sagernet/sing-tun"
    20  	"github.com/sagernet/sing/common"
    21  	"github.com/sagernet/sing/common/debug"
    22  	E "github.com/sagernet/sing/common/exceptions"
    23  	M "github.com/sagernet/sing/common/metadata"
    24  	N "github.com/sagernet/sing/common/network"
    25  	"github.com/sagernet/wireguard-go/device"
    26  )
    27  
    28  var (
    29  	_ adapter.Outbound                = (*WireGuard)(nil)
    30  	_ adapter.InterfaceUpdateListener = (*WireGuard)(nil)
    31  )
    32  
    33  type WireGuard struct {
    34  	myOutboundAdapter
    35  	bind      *wireguard.ClientBind
    36  	device    *device.Device
    37  	tunDevice wireguard.Device
    38  }
    39  
    40  func NewWireGuard(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardOutboundOptions) (*WireGuard, error) {
    41  	outbound := &WireGuard{
    42  		myOutboundAdapter: myOutboundAdapter{
    43  			protocol: C.TypeWireGuard,
    44  			network:  options.Network.Build(),
    45  			router:   router,
    46  			logger:   logger,
    47  			tag:      tag,
    48  		},
    49  	}
    50  	var reserved [3]uint8
    51  	if len(options.Reserved) > 0 {
    52  		if len(options.Reserved) != 3 {
    53  			return nil, E.New("invalid reserved value, required 3 bytes, got ", len(options.Reserved))
    54  		}
    55  		copy(reserved[:], options.Reserved)
    56  	}
    57  	peerAddr := options.ServerOptions.Build()
    58  	outbound.bind = wireguard.NewClientBind(ctx, dialer.New(router, options.DialerOptions), peerAddr, reserved)
    59  	localPrefixes := common.Map(options.LocalAddress, option.ListenPrefix.Build)
    60  	if len(localPrefixes) == 0 {
    61  		return nil, E.New("missing local address")
    62  	}
    63  	var privateKey, peerPublicKey, preSharedKey string
    64  	{
    65  		bytes, err := base64.StdEncoding.DecodeString(options.PrivateKey)
    66  		if err != nil {
    67  			return nil, E.Cause(err, "decode private key")
    68  		}
    69  		privateKey = hex.EncodeToString(bytes)
    70  	}
    71  	{
    72  		bytes, err := base64.StdEncoding.DecodeString(options.PeerPublicKey)
    73  		if err != nil {
    74  			return nil, E.Cause(err, "decode peer public key")
    75  		}
    76  		peerPublicKey = hex.EncodeToString(bytes)
    77  	}
    78  	if options.PreSharedKey != "" {
    79  		bytes, err := base64.StdEncoding.DecodeString(options.PreSharedKey)
    80  		if err != nil {
    81  			return nil, E.Cause(err, "decode pre shared key")
    82  		}
    83  		preSharedKey = hex.EncodeToString(bytes)
    84  	}
    85  	ipcConf := "private_key=" + privateKey
    86  	ipcConf += "\npublic_key=" + peerPublicKey
    87  	ipcConf += "\nendpoint=" + peerAddr.String()
    88  	if preSharedKey != "" {
    89  		ipcConf += "\npreshared_key=" + preSharedKey
    90  	}
    91  	var has4, has6 bool
    92  	for _, address := range localPrefixes {
    93  		if address.Addr().Is4() {
    94  			has4 = true
    95  		} else {
    96  			has6 = true
    97  		}
    98  	}
    99  	if has4 {
   100  		ipcConf += "\nallowed_ip=0.0.0.0/0"
   101  	}
   102  	if has6 {
   103  		ipcConf += "\nallowed_ip=::/0"
   104  	}
   105  	mtu := options.MTU
   106  	if mtu == 0 {
   107  		mtu = 1408
   108  	}
   109  	var wireTunDevice wireguard.Device
   110  	var err error
   111  	if !options.SystemInterface && tun.WithGVisor {
   112  		wireTunDevice, err = wireguard.NewStackDevice(localPrefixes, mtu)
   113  	} else {
   114  		wireTunDevice, err = wireguard.NewSystemDevice(router, options.InterfaceName, localPrefixes, mtu)
   115  	}
   116  	if err != nil {
   117  		return nil, E.Cause(err, "create WireGuard device")
   118  	}
   119  	wgDevice := device.NewDevice(wireTunDevice, outbound.bind, &device.Logger{
   120  		Verbosef: func(format string, args ...interface{}) {
   121  			logger.Debug(fmt.Sprintf(strings.ToLower(format), args...))
   122  		},
   123  		Errorf: func(format string, args ...interface{}) {
   124  			logger.Error(fmt.Sprintf(strings.ToLower(format), args...))
   125  		},
   126  	}, options.Workers)
   127  	if debug.Enabled {
   128  		logger.Trace("created wireguard ipc conf: \n", ipcConf)
   129  	}
   130  	err = wgDevice.IpcSet(ipcConf)
   131  	if err != nil {
   132  		return nil, E.Cause(err, "setup wireguard")
   133  	}
   134  	outbound.device = wgDevice
   135  	outbound.tunDevice = wireTunDevice
   136  	return outbound, nil
   137  }
   138  
   139  func (w *WireGuard) InterfaceUpdated() error {
   140  	w.bind.Reset()
   141  	return nil
   142  }
   143  
   144  func (w *WireGuard) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
   145  	switch network {
   146  	case N.NetworkTCP:
   147  		w.logger.InfoContext(ctx, "outbound connection to ", destination)
   148  	case N.NetworkUDP:
   149  		w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
   150  	}
   151  	if destination.IsFqdn() {
   152  		addrs, err := w.router.LookupDefault(ctx, destination.Fqdn)
   153  		if err != nil {
   154  			return nil, err
   155  		}
   156  		return N.DialSerial(ctx, w.tunDevice, network, destination, addrs)
   157  	}
   158  	return w.tunDevice.DialContext(ctx, network, destination)
   159  }
   160  
   161  func (w *WireGuard) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
   162  	w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
   163  	return w.tunDevice.ListenPacket(ctx, destination)
   164  }
   165  
   166  func (w *WireGuard) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
   167  	return NewConnection(ctx, w, conn, metadata)
   168  }
   169  
   170  func (w *WireGuard) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
   171  	return NewPacketConnection(ctx, w, conn, metadata)
   172  }
   173  
   174  func (w *WireGuard) Start() error {
   175  	return w.tunDevice.Start()
   176  }
   177  
   178  func (w *WireGuard) Close() error {
   179  	if w.device != nil {
   180  		w.device.Close()
   181  	}
   182  	w.tunDevice.Close()
   183  	return nil
   184  }