github.com/mysteriumnetwork/node@v0.0.0-20240516044423-365054f76801/mobile/mysterium/wireguard_connection_setup.go (about)

     1  /*
     2   * Copyright (C) 2018 The "MysteriumNetwork/node" Authors.
     3   *
     4   * This program is free software: you can redistribute it and/or modify
     5   * it under the terms of the GNU General Public License as published by
     6   * the Free Software Foundation, either version 3 of the License, or
     7   * (at your option) any later version.
     8   *
     9   * This program is distributed in the hope that it will be useful,
    10   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    11   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    12   * GNU General Public License for more details.
    13   *
    14   * You should have received a copy of the GNU General Public License
    15   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    16   */
    17  
    18  package mysterium
    19  
    20  import (
    21  	"bufio"
    22  	"context"
    23  	"encoding/json"
    24  	"fmt"
    25  	"net"
    26  	"strings"
    27  	"sync"
    28  	"time"
    29  
    30  	"github.com/pkg/errors"
    31  	"github.com/rs/zerolog/log"
    32  	"golang.zx2c4.com/wireguard/conn"
    33  	"golang.zx2c4.com/wireguard/device"
    34  	"golang.zx2c4.com/wireguard/tun"
    35  
    36  	"github.com/mysteriumnetwork/node/core/connection"
    37  	"github.com/mysteriumnetwork/node/core/connection/connectionstate"
    38  	"github.com/mysteriumnetwork/node/core/ip"
    39  	"github.com/mysteriumnetwork/node/services/wireguard"
    40  	wireguard_connection "github.com/mysteriumnetwork/node/services/wireguard/connection"
    41  	"github.com/mysteriumnetwork/node/services/wireguard/endpoint/userspace"
    42  	"github.com/mysteriumnetwork/node/services/wireguard/key"
    43  	"github.com/mysteriumnetwork/node/services/wireguard/wgcfg"
    44  )
    45  
    46  const (
    47  	// Taken from android-wireguard project
    48  	androidTunMtu = 1280
    49  )
    50  
    51  // WireguardTunnelSetup exposes api for caller to implement external tunnel setup
    52  type WireguardTunnelSetup interface {
    53  	NewTunnel()
    54  	AddTunnelAddress(ip string, prefixLen int)
    55  	AddRoute(route string, prefixLen int)
    56  	AddDNS(ip string)
    57  	SetBlocking(blocking bool)
    58  	Establish() (int, error)
    59  	SetMTU(mtu int)
    60  	Protect(socket int) error
    61  	SetSessionName(session string)
    62  }
    63  
// wireGuardOptions holds tunables for a wireguardConnection.
type wireGuardOptions struct {
	// statsUpdateInterval is not read anywhere in this file; presumably it is
	// used by statistics polling elsewhere in the package — TODO confirm.
	statsUpdateInterval time.Duration
	// handshakeTimeout is passed as the timeout to handshakeWaiter.Wait when
	// Start waits for the first handshake with the provider.
	handshakeTimeout    time.Duration
}
    68  
    69  // NewWireGuardConnection creates a new wireguard connection
    70  func NewWireGuardConnection(opts wireGuardOptions, device wireguardDevice, ipResolver ip.Resolver, handshakeWaiter wireguard_connection.HandshakeWaiter) (connection.Connection, error) {
    71  	privateKey, err := key.GeneratePrivateKey()
    72  	if err != nil {
    73  		return nil, err
    74  	}
    75  
    76  	return &wireguardConnection{
    77  		done:            make(chan struct{}),
    78  		stateCh:         make(chan connectionstate.State, 100),
    79  		opts:            opts,
    80  		device:          device,
    81  		privateKey:      privateKey,
    82  		ipResolver:      ipResolver,
    83  		handshakeWaiter: handshakeWaiter,
    84  	}, nil
    85  }
    86  
    87  type wireguardConnection struct {
    88  	ports           []int
    89  	closeOnce       sync.Once
    90  	done            chan struct{}
    91  	stateCh         chan connectionstate.State
    92  	opts            wireGuardOptions
    93  	privateKey      string
    94  	device          wireguardDevice
    95  	ipResolver      ip.Resolver
    96  	handshakeWaiter wireguard_connection.HandshakeWaiter
    97  }
    98  
    99  var _ connection.Connection = &wireguardConnection{}
   100  
   101  func (c *wireguardConnection) State() <-chan connectionstate.State {
   102  	return c.stateCh
   103  }
   104  
   105  func (c *wireguardConnection) Statistics() (connectionstate.Statistics, error) {
   106  	stats, err := c.device.Stats()
   107  	if err != nil {
   108  		return connectionstate.Statistics{}, err
   109  	}
   110  	return connectionstate.Statistics{
   111  		At:            time.Now(),
   112  		BytesSent:     stats.BytesSent,
   113  		BytesReceived: stats.BytesReceived,
   114  	}, nil
   115  }
   116  
   117  func (c *wireguardConnection) Reconnect(ctx context.Context, options connection.ConnectOptions) (err error) {
   118  	return c.Start(ctx, options)
   119  }
   120  
   121  func (c *wireguardConnection) Start(ctx context.Context, options connection.ConnectOptions) (err error) {
   122  	var config wireguard.ServiceConfig
   123  	err = json.Unmarshal(options.SessionConfig, &config)
   124  	if err != nil {
   125  		return errors.Wrap(err, "could not parse wireguard session config")
   126  	}
   127  
   128  	c.stateCh <- connectionstate.Connecting
   129  
   130  	defer func() {
   131  		if err != nil {
   132  			c.Stop()
   133  		}
   134  	}()
   135  
   136  	if options.ProviderNATConn != nil {
   137  		options.ProviderNATConn.Close()
   138  		config.LocalPort = options.ProviderNATConn.LocalAddr().(*net.UDPAddr).Port
   139  		config.Provider.Endpoint.Port = options.ProviderNATConn.RemoteAddr().(*net.UDPAddr).Port
   140  	}
   141  
   142  	if err = c.device.Start(c.privateKey, config, options.ChannelConn, options.Params.DNS); err != nil {
   143  		return errors.Wrap(err, "could not start device")
   144  	}
   145  
   146  	if err = c.handshakeWaiter.Wait(ctx, c.device.Stats, c.opts.handshakeTimeout, c.done); err != nil {
   147  		return errors.Wrap(err, "failed to handshake")
   148  	}
   149  
   150  	log.Debug().Msg("Connected successfully")
   151  	c.stateCh <- connectionstate.Connected
   152  	return nil
   153  }
   154  
// Stop tears the connection down at most once. It publishes Disconnecting,
// stops the device, publishes NotConnected, then closes stateCh (ending
// State() consumers) and done (the channel Start hands to the handshake
// waiter). closeOnce makes later calls no-ops.
//
// NOTE(review): the two state sends need free buffer space in stateCh (it is
// created with capacity 100 in NewWireGuardConnection) or an active reader.
// Any Start/Reconnect send after Stop has run would panic on the closed channel.
func (c *wireguardConnection) Stop() {
	c.closeOnce.Do(func() {
		c.stateCh <- connectionstate.Disconnecting
		c.device.Stop()
		c.stateCh <- connectionstate.NotConnected

		close(c.stateCh)
		close(c.done)
	})
}
   165  
   166  func (c *wireguardConnection) GetConfig() (connection.ConsumerConfig, error) {
   167  	if c.privateKey == "" {
   168  		return nil, errors.New("private key is missing")
   169  	}
   170  	publicKey, err := key.PrivateKeyToPublicKey(c.privateKey)
   171  	if err != nil {
   172  		return nil, errors.Wrap(err, "could not get public key from private key")
   173  	}
   174  
   175  	return wireguard.ConsumerConfig{
   176  		PublicKey: publicKey,
   177  		Ports:     c.ports,
   178  	}, nil
   179  }
   180  
// wireguardDevice abstracts the wireguard device driven by a
// wireguardConnection. wireguardDeviceImpl is the implementation in this file.
type wireguardDevice interface {
	// Start creates and configures the device for the given session config.
	// channelConn, when non-nil, is the p2p channel socket to keep out of the
	// tunnel. dns selects the DNS servers configured on the tunnel.
	Start(privateKey string, config wireguard.ServiceConfig, channelConn *net.UDPConn, dns connection.DNSOption) error
	// Stop shuts the device down.
	Stop()
	// Stats returns the current peer transfer statistics.
	Stats() (wgcfg.Stats, error)
}
   186  
   187  func newWireguardDevice(tunnelSetup WireguardTunnelSetup) wireguardDevice {
   188  	return &wireguardDeviceImpl{tunnelSetup: tunnelSetup}
   189  }
   190  
// wireguardDeviceImpl implements wireguardDevice on top of a wireguard-go
// (golang.zx2c4.com/wireguard/device) device.
type wireguardDeviceImpl struct {
	// tunnelSetup is the host-platform hook used to create the tun interface
	// and to protect sockets from the tunnel.
	tunnelSetup WireguardTunnelSetup

	// device is the most recently created wireguard-go device. It is nil until
	// Start creates one, and Stop closes it but does not clear it.
	device *device.Device
}
   196  
   197  func (w *wireguardDeviceImpl) Start(privateKey string, config wireguard.ServiceConfig, channelConn *net.UDPConn, dns connection.DNSOption) error {
   198  	log.Debug().Msg("Creating tunnel device")
   199  	tunDevice, err := w.newTunnDevice(w.tunnelSetup, config, dns)
   200  	if err != nil {
   201  		return errors.Wrap(err, "could not create tunnel device")
   202  	}
   203  
   204  	oldDevice := w.device
   205  	defer func() {
   206  		if oldDevice != nil {
   207  			oldDevice.Close()
   208  		}
   209  	}()
   210  
   211  	w.device = device.NewDevice(tunDevice, conn.NewStdNetBind(), device.NewLogger(device.LogLevelVerbose, "[userspace-wg]"))
   212  
   213  	err = w.applyConfig(w.device, privateKey, config)
   214  	if err != nil {
   215  		return errors.Wrap(err, "could not setup device configuration")
   216  	}
   217  	w.device.Up()
   218  	socket, err := peekLookAtSocketFd4(w.device)
   219  	if err != nil {
   220  		return errors.Wrap(err, "could not get socket")
   221  	}
   222  	err = w.tunnelSetup.Protect(socket)
   223  	if err != nil {
   224  		return errors.Wrap(err, "could not protect socket")
   225  	}
   226  
   227  	// Exclude p2p channel traffic from VPN tunnel.
   228  	if channelConn != nil {
   229  		channelSocket, err := peekLookAtSocketFd4From(channelConn)
   230  		if err != nil {
   231  			return fmt.Errorf("could not get channel socket: %w", err)
   232  		}
   233  		err = w.tunnelSetup.Protect(channelSocket)
   234  		if err != nil {
   235  			return fmt.Errorf("could not protect p2p socket: %w", err)
   236  		}
   237  	}
   238  
   239  	return nil
   240  }
   241  
   242  func (w *wireguardDeviceImpl) Stop() {
   243  	if w.device != nil {
   244  		w.device.Close()
   245  	}
   246  }
   247  
   248  func (w *wireguardDeviceImpl) Stats() (wgcfg.Stats, error) {
   249  	if w.device == nil {
   250  		return wgcfg.Stats{}, errors.New("device is not started")
   251  	}
   252  	deviceState, err := userspace.ParseUserspaceDevice(w.device.IpcGetOperation)
   253  	if err != nil {
   254  		return wgcfg.Stats{}, errors.Wrap(err, "could not parse userspace wg device state")
   255  	}
   256  	stats, err := userspace.ParseDevicePeerStats(deviceState)
   257  	if err != nil {
   258  		return wgcfg.Stats{}, errors.Wrap(err, "could not get userspace wg peer stats")
   259  	}
   260  	return stats, nil
   261  }
   262  
   263  func (w *wireguardDeviceImpl) applyConfig(devApi *device.Device, privateKey string, config wireguard.ServiceConfig) error {
   264  	deviceConfig := wgcfg.DeviceConfig{
   265  		PrivateKey: privateKey,
   266  		ListenPort: config.LocalPort,
   267  		Peer: wgcfg.Peer{
   268  			Endpoint:               &config.Provider.Endpoint,
   269  			PublicKey:              config.Provider.PublicKey,
   270  			KeepAlivePeriodSeconds: 18,
   271  			// All traffic through this peer (unfortunately 0.0.0.0/0 didn't work as it was treated as ipv6)
   272  			AllowedIPs: []string{"0.0.0.0/1", "128.0.0.0/1"},
   273  		},
   274  		ReplacePeers: true,
   275  	}
   276  
   277  	if err := devApi.IpcSetOperation(bufio.NewReader(strings.NewReader(deviceConfig.Encode()))); err != nil {
   278  		return fmt.Errorf("could not complete ipc operation: %w", err)
   279  	}
   280  	return nil
   281  }
   282  
   283  func (w *wireguardDeviceImpl) newTunnDevice(wgTunnSetup WireguardTunnelSetup, config wireguard.ServiceConfig, dns connection.DNSOption) (tun.Device, error) {
   284  	consumerIP := config.Consumer.IPAddress
   285  	prefixLen, _ := consumerIP.Mask.Size()
   286  	wgTunnSetup.NewTunnel()
   287  	wgTunnSetup.SetSessionName("wg-tun-session")
   288  	wgTunnSetup.AddTunnelAddress(consumerIP.IP.String(), prefixLen)
   289  	wgTunnSetup.SetMTU(androidTunMtu)
   290  	wgTunnSetup.SetBlocking(true)
   291  
   292  	dnsIPs, err := dns.ResolveIPs(config.Consumer.DNSIPs)
   293  	if err != nil {
   294  		return nil, err
   295  	}
   296  	for _, dnsIP := range dnsIPs {
   297  		wgTunnSetup.AddDNS(dnsIP)
   298  	}
   299  
   300  	// Route all traffic through tunnel
   301  	wgTunnSetup.AddRoute("0.0.0.0", 1)
   302  	wgTunnSetup.AddRoute("128.0.0.0", 1)
   303  	wgTunnSetup.AddRoute("::", 1)
   304  	wgTunnSetup.AddRoute("8000::", 1)
   305  
   306  	fd, err := wgTunnSetup.Establish()
   307  	if err != nil {
   308  		return nil, err
   309  	}
   310  	log.Info().Msgf("Tun value is: %d", fd)
   311  	tunDevice, err := newDeviceFromFd(fd)
   312  	if err == nil {
   313  		// non-fatal
   314  		name, nameErr := tunDevice.Name()
   315  		log.Info().Err(nameErr).Msg("Name value: " + name)
   316  	}
   317  
   318  	return tunDevice, err
   319  }