github.com/amnezia-vpn/amneziawg-go@v0.2.8/device/uapi.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"bufio"
    10  	"bytes"
    11  	"errors"
    12  	"fmt"
    13  	"io"
    14  	"net"
    15  	"net/netip"
    16  	"strconv"
    17  	"strings"
    18  	"sync"
    19  	"time"
    20  
    21  	"github.com/amnezia-vpn/amneziawg-go/ipc"
    22  )
    23  
    24  type IPCError struct {
    25  	code int64 // error code
    26  	err  error // underlying/wrapped error
    27  }
    28  
    29  func (s IPCError) Error() string {
    30  	return fmt.Sprintf("IPC error %d: %v", s.code, s.err)
    31  }
    32  
    33  func (s IPCError) Unwrap() error {
    34  	return s.err
    35  }
    36  
    37  func (s IPCError) ErrorCode() int64 {
    38  	return s.code
    39  }
    40  
    41  func ipcErrorf(code int64, msg string, args ...any) *IPCError {
    42  	return &IPCError{code: code, err: fmt.Errorf(msg, args...)}
    43  }
    44  
    45  var byteBufferPool = &sync.Pool{
    46  	New: func() any { return new(bytes.Buffer) },
    47  }
    48  
    49  // IpcGetOperation implements the WireGuard configuration protocol "get" operation.
    50  // See https://www.wireguard.com/xplatform/#configuration-protocol for details.
    51  func (device *Device) IpcGetOperation(w io.Writer) error {
    52  	device.ipcMutex.RLock()
    53  	defer device.ipcMutex.RUnlock()
    54  
    55  	buf := byteBufferPool.Get().(*bytes.Buffer)
    56  	buf.Reset()
    57  	defer byteBufferPool.Put(buf)
    58  	sendf := func(format string, args ...any) {
    59  		fmt.Fprintf(buf, format, args...)
    60  		buf.WriteByte('\n')
    61  	}
    62  	keyf := func(prefix string, key *[32]byte) {
    63  		buf.Grow(len(key)*2 + 2 + len(prefix))
    64  		buf.WriteString(prefix)
    65  		buf.WriteByte('=')
    66  		const hex = "0123456789abcdef"
    67  		for i := 0; i < len(key); i++ {
    68  			buf.WriteByte(hex[key[i]>>4])
    69  			buf.WriteByte(hex[key[i]&0xf])
    70  		}
    71  		buf.WriteByte('\n')
    72  	}
    73  
    74  	func() {
    75  		// lock required resources
    76  
    77  		device.net.RLock()
    78  		defer device.net.RUnlock()
    79  
    80  		device.staticIdentity.RLock()
    81  		defer device.staticIdentity.RUnlock()
    82  
    83  		device.peers.RLock()
    84  		defer device.peers.RUnlock()
    85  
    86  		// serialize device related values
    87  
    88  		if !device.staticIdentity.privateKey.IsZero() {
    89  			keyf("private_key", (*[32]byte)(&device.staticIdentity.privateKey))
    90  		}
    91  
    92  		if device.net.port != 0 {
    93  			sendf("listen_port=%d", device.net.port)
    94  		}
    95  
    96  		if device.net.fwmark != 0 {
    97  			sendf("fwmark=%d", device.net.fwmark)
    98  		}
    99  
   100  		if device.isAdvancedSecurityOn() {
   101  			if device.aSecCfg.junkPacketCount != 0 {
   102  				sendf("jc=%d", device.aSecCfg.junkPacketCount)
   103  			}
   104  			if device.aSecCfg.junkPacketMinSize != 0 {
   105  				sendf("jmin=%d", device.aSecCfg.junkPacketMinSize)
   106  			}
   107  			if device.aSecCfg.junkPacketMaxSize != 0 {
   108  				sendf("jmax=%d", device.aSecCfg.junkPacketMaxSize)
   109  			}
   110  			if device.aSecCfg.initPacketJunkSize != 0 {
   111  				sendf("s1=%d", device.aSecCfg.initPacketJunkSize)
   112  			}
   113  			if device.aSecCfg.responsePacketJunkSize != 0 {
   114  				sendf("s2=%d", device.aSecCfg.responsePacketJunkSize)
   115  			}
   116  			if device.aSecCfg.initPacketMagicHeader != 0 {
   117  				sendf("h1=%d", device.aSecCfg.initPacketMagicHeader)
   118  			}
   119  			if device.aSecCfg.responsePacketMagicHeader != 0 {
   120  				sendf("h2=%d", device.aSecCfg.responsePacketMagicHeader)
   121  			}
   122  			if device.aSecCfg.underloadPacketMagicHeader != 0 {
   123  				sendf("h3=%d", device.aSecCfg.underloadPacketMagicHeader)
   124  			}
   125  			if device.aSecCfg.transportPacketMagicHeader != 0 {
   126  				sendf("h4=%d", device.aSecCfg.transportPacketMagicHeader)
   127  			}
   128  		}
   129  
   130  		for _, peer := range device.peers.keyMap {
   131  			// Serialize peer state.
   132  			peer.handshake.mutex.RLock()
   133  			keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic))
   134  			keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey))
   135  			peer.handshake.mutex.RUnlock()
   136  			sendf("protocol_version=1")
   137  			peer.endpoint.Lock()
   138  			if peer.endpoint.val != nil {
   139  				sendf("endpoint=%s", peer.endpoint.val.DstToString())
   140  			}
   141  			peer.endpoint.Unlock()
   142  
   143  			nano := peer.lastHandshakeNano.Load()
   144  			secs := nano / time.Second.Nanoseconds()
   145  			nano %= time.Second.Nanoseconds()
   146  
   147  			sendf("last_handshake_time_sec=%d", secs)
   148  			sendf("last_handshake_time_nsec=%d", nano)
   149  			sendf("tx_bytes=%d", peer.txBytes.Load())
   150  			sendf("rx_bytes=%d", peer.rxBytes.Load())
   151  			sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load())
   152  
   153  			device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
   154  				sendf("allowed_ip=%s", prefix.String())
   155  				return true
   156  			})
   157  		}
   158  	}()
   159  
   160  	// send lines (does not require resource locks)
   161  	if _, err := w.Write(buf.Bytes()); err != nil {
   162  		return ipcErrorf(ipc.IpcErrorIO, "failed to write output: %w", err)
   163  	}
   164  
   165  	return nil
   166  }
   167  
   168  // IpcSetOperation implements the WireGuard configuration protocol "set" operation.
   169  // See https://www.wireguard.com/xplatform/#configuration-protocol for details.
   170  func (device *Device) IpcSetOperation(r io.Reader) (err error) {
   171  	device.ipcMutex.Lock()
   172  	defer device.ipcMutex.Unlock()
   173  
   174  	defer func() {
   175  		if err != nil {
   176  			device.log.Errorf("%v", err)
   177  		}
   178  	}()
   179  
   180  	peer := new(ipcSetPeer)
   181  	deviceConfig := true
   182  
   183  	tempASecCfg := aSecCfgType{}
   184  	scanner := bufio.NewScanner(r)
   185  	for scanner.Scan() {
   186  		line := scanner.Text()
   187  		if line == "" {
   188  			// Blank line means terminate operation.
   189  			err := device.handlePostConfig(&tempASecCfg)
   190  			if err != nil {
   191  				return err
   192  			}
   193  			peer.handlePostConfig()
   194  			return nil
   195  		}
   196  		key, value, ok := strings.Cut(line, "=")
   197  		if !ok {
   198  			return ipcErrorf(
   199  				ipc.IpcErrorProtocol,
   200  				"failed to parse line %q",
   201  				line,
   202  			)
   203  		}
   204  
   205  		if key == "public_key" {
   206  			if deviceConfig {
   207  				deviceConfig = false
   208  			}
   209  			peer.handlePostConfig()
   210  			// Load/create the peer we are now configuring.
   211  			err := device.handlePublicKeyLine(peer, value)
   212  			if err != nil {
   213  				return err
   214  			}
   215  			continue
   216  		}
   217  
   218  		var err error
   219  		if deviceConfig {
   220  			err = device.handleDeviceLine(key, value, &tempASecCfg)
   221  		} else {
   222  			err = device.handlePeerLine(peer, key, value)
   223  		}
   224  		if err != nil {
   225  			return err
   226  		}
   227  	}
   228  	err = device.handlePostConfig(&tempASecCfg)
   229  	if err != nil {
   230  		return err
   231  	}
   232  	peer.handlePostConfig()
   233  
   234  	if err := scanner.Err(); err != nil {
   235  		return ipcErrorf(ipc.IpcErrorIO, "failed to read input: %w", err)
   236  	}
   237  	return nil
   238  }
   239  
   240  func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgType) error {
   241  	switch key {
   242  	case "private_key":
   243  		var sk NoisePrivateKey
   244  		err := sk.FromMaybeZeroHex(value)
   245  		if err != nil {
   246  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set private_key: %w", err)
   247  		}
   248  		device.log.Verbosef("UAPI: Updating private key")
   249  		device.SetPrivateKey(sk)
   250  
   251  	case "listen_port":
   252  		port, err := strconv.ParseUint(value, 10, 16)
   253  		if err != nil {
   254  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse listen_port: %w", err)
   255  		}
   256  
   257  		// update port and rebind
   258  		device.log.Verbosef("UAPI: Updating listen port")
   259  
   260  		device.net.Lock()
   261  		device.net.port = uint16(port)
   262  		device.net.Unlock()
   263  
   264  		if err := device.BindUpdate(); err != nil {
   265  			return ipcErrorf(ipc.IpcErrorPortInUse, "failed to set listen_port: %w", err)
   266  		}
   267  
   268  	case "fwmark":
   269  		mark, err := strconv.ParseUint(value, 10, 32)
   270  		if err != nil {
   271  			return ipcErrorf(ipc.IpcErrorInvalid, "invalid fwmark: %w", err)
   272  		}
   273  
   274  		device.log.Verbosef("UAPI: Updating fwmark")
   275  		if err := device.BindSetMark(uint32(mark)); err != nil {
   276  			return ipcErrorf(ipc.IpcErrorPortInUse, "failed to update fwmark: %w", err)
   277  		}
   278  
   279  	case "replace_peers":
   280  		if value != "true" {
   281  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value)
   282  		}
   283  		device.log.Verbosef("UAPI: Removing all peers")
   284  		device.RemoveAllPeers()
   285  
   286  	case "jc":
   287  		junkPacketCount, err := strconv.Atoi(value)
   288  		if err != nil {
   289  			return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_count %w", err)
   290  		}
   291  		device.log.Verbosef("UAPI: Updating junk_packet_count")
   292  		tempASecCfg.junkPacketCount = junkPacketCount
   293  		tempASecCfg.isSet = true
   294  
   295  	case "jmin":
   296  		junkPacketMinSize, err := strconv.Atoi(value)
   297  		if err != nil {
   298  			return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_min_size %w", err)
   299  		}
   300  		device.log.Verbosef("UAPI: Updating junk_packet_min_size")
   301  		tempASecCfg.junkPacketMinSize = junkPacketMinSize
   302  		tempASecCfg.isSet = true
   303  
   304  	case "jmax":
   305  		junkPacketMaxSize, err := strconv.Atoi(value)
   306  		if err != nil {
   307  			return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_max_size %w", err)
   308  		}
   309  		device.log.Verbosef("UAPI: Updating junk_packet_max_size")
   310  		tempASecCfg.junkPacketMaxSize = junkPacketMaxSize
   311  		tempASecCfg.isSet = true
   312  
   313  	case "s1":
   314  		initPacketJunkSize, err := strconv.Atoi(value)
   315  		if err != nil {
   316  			return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_junk_size %w", err)
   317  		}
   318  		device.log.Verbosef("UAPI: Updating init_packet_junk_size")
   319  		tempASecCfg.initPacketJunkSize = initPacketJunkSize
   320  		tempASecCfg.isSet = true
   321  
   322  	case "s2":
   323  		responsePacketJunkSize, err := strconv.Atoi(value)
   324  		if err != nil {
   325  			return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_junk_size %w", err)
   326  		}
   327  		device.log.Verbosef("UAPI: Updating response_packet_junk_size")
   328  		tempASecCfg.responsePacketJunkSize = responsePacketJunkSize
   329  		tempASecCfg.isSet = true
   330  
   331  	case "h1":
   332  		initPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
   333  		if err != nil {
   334  			return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_magic_header %w", err)
   335  		}
   336  		tempASecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader)
   337  		tempASecCfg.isSet = true
   338  
   339  	case "h2":
   340  		responsePacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
   341  		if err != nil {
   342  			return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_magic_header %w", err)
   343  		}
   344  		tempASecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader)
   345  		tempASecCfg.isSet = true
   346  
   347  	case "h3":
   348  		underloadPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
   349  		if err != nil {
   350  			return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse underload_packet_magic_header %w", err)
   351  		}
   352  		tempASecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader)
   353  		tempASecCfg.isSet = true
   354  
   355  	case "h4":
   356  		transportPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
   357  		if err != nil {
   358  			return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse transport_packet_magic_header %w", err)
   359  		}
   360  		tempASecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader)
   361  		tempASecCfg.isSet = true
   362  
   363  	default:
   364  		return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)
   365  	}
   366  
   367  	return nil
   368  }
   369  
   370  // An ipcSetPeer is the current state of an IPC set operation on a peer.
   371  type ipcSetPeer struct {
   372  	*Peer        // Peer is the current peer being operated on
   373  	dummy   bool // dummy reports whether this peer is a temporary, placeholder peer
   374  	created bool // new reports whether this is a newly created peer
   375  	pkaOn   bool // pkaOn reports whether the peer had the persistent keepalive turn on
   376  }
   377  
   378  func (peer *ipcSetPeer) handlePostConfig() {
   379  	if peer.Peer == nil || peer.dummy {
   380  		return
   381  	}
   382  	if peer.created {
   383  		peer.endpoint.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint.val != nil
   384  	}
   385  	if peer.device.isUp() {
   386  		peer.Start()
   387  		if peer.pkaOn {
   388  			peer.SendKeepalive()
   389  		}
   390  		peer.SendStagedPackets()
   391  	}
   392  }
   393  
   394  func (device *Device) handlePublicKeyLine(
   395  	peer *ipcSetPeer,
   396  	value string,
   397  ) error {
   398  	// Load/create the peer we are configuring.
   399  	var publicKey NoisePublicKey
   400  	err := publicKey.FromHex(value)
   401  	if err != nil {
   402  		return ipcErrorf(ipc.IpcErrorInvalid, "failed to get peer by public key: %w", err)
   403  	}
   404  
   405  	// Ignore peer with the same public key as this device.
   406  	device.staticIdentity.RLock()
   407  	peer.dummy = device.staticIdentity.publicKey.Equals(publicKey)
   408  	device.staticIdentity.RUnlock()
   409  
   410  	if peer.dummy {
   411  		peer.Peer = &Peer{}
   412  	} else {
   413  		peer.Peer = device.LookupPeer(publicKey)
   414  	}
   415  
   416  	peer.created = peer.Peer == nil
   417  	if peer.created {
   418  		peer.Peer, err = device.NewPeer(publicKey)
   419  		if err != nil {
   420  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to create new peer: %w", err)
   421  		}
   422  		device.log.Verbosef("%v - UAPI: Created", peer.Peer)
   423  	}
   424  	return nil
   425  }
   426  
   427  func (device *Device) handlePeerLine(
   428  	peer *ipcSetPeer,
   429  	key, value string,
   430  ) error {
   431  	switch key {
   432  	case "update_only":
   433  		// allow disabling of creation
   434  		if value != "true" {
   435  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value)
   436  		}
   437  		if peer.created && !peer.dummy {
   438  			device.RemovePeer(peer.handshake.remoteStatic)
   439  			peer.Peer = &Peer{}
   440  			peer.dummy = true
   441  		}
   442  
   443  	case "remove":
   444  		// remove currently selected peer from device
   445  		if value != "true" {
   446  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set remove, invalid value: %v", value)
   447  		}
   448  		if !peer.dummy {
   449  			device.log.Verbosef("%v - UAPI: Removing", peer.Peer)
   450  			device.RemovePeer(peer.handshake.remoteStatic)
   451  		}
   452  		peer.Peer = &Peer{}
   453  		peer.dummy = true
   454  
   455  	case "preshared_key":
   456  		device.log.Verbosef("%v - UAPI: Updating preshared key", peer.Peer)
   457  
   458  		peer.handshake.mutex.Lock()
   459  		err := peer.handshake.presharedKey.FromHex(value)
   460  		peer.handshake.mutex.Unlock()
   461  
   462  		if err != nil {
   463  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set preshared key: %w", err)
   464  		}
   465  
   466  	case "endpoint":
   467  		device.log.Verbosef("%v - UAPI: Updating endpoint", peer.Peer)
   468  		endpoint, err := device.net.bind.ParseEndpoint(value)
   469  		if err != nil {
   470  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err)
   471  		}
   472  		peer.endpoint.Lock()
   473  		defer peer.endpoint.Unlock()
   474  		peer.endpoint.val = endpoint
   475  
   476  	case "persistent_keepalive_interval":
   477  		device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer)
   478  
   479  		secs, err := strconv.ParseUint(value, 10, 16)
   480  		if err != nil {
   481  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err)
   482  		}
   483  
   484  		old := peer.persistentKeepaliveInterval.Swap(uint32(secs))
   485  
   486  		// Send immediate keepalive if we're turning it on and before it wasn't on.
   487  		peer.pkaOn = old == 0 && secs != 0
   488  
   489  	case "replace_allowed_ips":
   490  		device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer)
   491  		if value != "true" {
   492  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value)
   493  		}
   494  		if peer.dummy {
   495  			return nil
   496  		}
   497  		device.allowedips.RemoveByPeer(peer.Peer)
   498  
   499  	case "allowed_ip":
   500  		device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
   501  		prefix, err := netip.ParsePrefix(value)
   502  		if err != nil {
   503  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
   504  		}
   505  		if peer.dummy {
   506  			return nil
   507  		}
   508  		device.allowedips.Insert(prefix, peer.Peer)
   509  
   510  	case "protocol_version":
   511  		if value != "1" {
   512  			return ipcErrorf(ipc.IpcErrorInvalid, "invalid protocol version: %v", value)
   513  		}
   514  
   515  	default:
   516  		return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI peer key: %v", key)
   517  	}
   518  
   519  	return nil
   520  }
   521  
   522  func (device *Device) IpcGet() (string, error) {
   523  	buf := new(strings.Builder)
   524  	if err := device.IpcGetOperation(buf); err != nil {
   525  		return "", err
   526  	}
   527  	return buf.String(), nil
   528  }
   529  
   530  func (device *Device) IpcSet(uapiConf string) error {
   531  	return device.IpcSetOperation(strings.NewReader(uapiConf))
   532  }
   533  
   534  func (device *Device) IpcHandle(socket net.Conn) {
   535  	defer socket.Close()
   536  
   537  	buffered := func(s io.ReadWriter) *bufio.ReadWriter {
   538  		reader := bufio.NewReader(s)
   539  		writer := bufio.NewWriter(s)
   540  		return bufio.NewReadWriter(reader, writer)
   541  	}(socket)
   542  
   543  	for {
   544  		op, err := buffered.ReadString('\n')
   545  		if err != nil {
   546  			return
   547  		}
   548  
   549  		// handle operation
   550  		switch op {
   551  		case "set=1\n":
   552  			err = device.IpcSetOperation(buffered.Reader)
   553  		case "get=1\n":
   554  			var nextByte byte
   555  			nextByte, err = buffered.ReadByte()
   556  			if err != nil {
   557  				return
   558  			}
   559  			if nextByte != '\n' {
   560  				err = ipcErrorf(ipc.IpcErrorInvalid, "trailing character in UAPI get: %q", nextByte)
   561  				break
   562  			}
   563  			err = device.IpcGetOperation(buffered.Writer)
   564  		default:
   565  			device.log.Errorf("invalid UAPI operation: %v", op)
   566  			return
   567  		}
   568  
   569  		// write status
   570  		var status *IPCError
   571  		if err != nil && !errors.As(err, &status) {
   572  			// shouldn't happen
   573  			status = ipcErrorf(ipc.IpcErrorUnknown, "other UAPI error: %w", err)
   574  		}
   575  		if status != nil {
   576  			device.log.Errorf("%v", status)
   577  			fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode())
   578  		} else {
   579  			fmt.Fprintf(buffered, "errno=0\n\n")
   580  		}
   581  		buffered.Flush()
   582  	}
   583  }