github.com/cawidtu/notwireguard-go/device@v0.0.0-20230523131112-68e8e5ce9cdf/uapi.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"bufio"
    10  	"bytes"
    11  	"errors"
    12  	"fmt"
    13  	"io"
    14  	"net"
    15  	"net/netip"
    16  	"strconv"
    17  	"strings"
    18  	"sync"
    19  	"sync/atomic"
    20  	"time"
    21  
    22  	"github.com/cawidtu/notwireguard-go/ipc"
    23  )
    24  
    25  type IPCError struct {
    26  	code int64 // error code
    27  	err  error // underlying/wrapped error
    28  }
    29  
    30  func (s IPCError) Error() string {
    31  	return fmt.Sprintf("IPC error %d: %v", s.code, s.err)
    32  }
    33  
    34  func (s IPCError) Unwrap() error {
    35  	return s.err
    36  }
    37  
    38  func (s IPCError) ErrorCode() int64 {
    39  	return s.code
    40  }
    41  
    42  func ipcErrorf(code int64, msg string, args ...any) *IPCError {
    43  	return &IPCError{code: code, err: fmt.Errorf(msg, args...)}
    44  }
    45  
    46  var byteBufferPool = &sync.Pool{
    47  	New: func() any { return new(bytes.Buffer) },
    48  }
    49  
    50  // IpcGetOperation implements the WireGuard configuration protocol "get" operation.
    51  // See https://www.wireguard.com/xplatform/#configuration-protocol for details.
    52  func (device *Device) IpcGetOperation(w io.Writer) error {
    53  	device.ipcMutex.RLock()
    54  	defer device.ipcMutex.RUnlock()
    55  
    56  	buf := byteBufferPool.Get().(*bytes.Buffer)
    57  	buf.Reset()
    58  	defer byteBufferPool.Put(buf)
    59  	sendf := func(format string, args ...any) {
    60  		fmt.Fprintf(buf, format, args...)
    61  		buf.WriteByte('\n')
    62  	}
    63  	keyf := func(prefix string, key *[32]byte) {
    64  		buf.Grow(len(key)*2 + 2 + len(prefix))
    65  		buf.WriteString(prefix)
    66  		buf.WriteByte('=')
    67  		const hex = "0123456789abcdef"
    68  		for i := 0; i < len(key); i++ {
    69  			buf.WriteByte(hex[key[i]>>4])
    70  			buf.WriteByte(hex[key[i]&0xf])
    71  		}
    72  		buf.WriteByte('\n')
    73  	}
    74  
    75  	func() {
    76  		// lock required resources
    77  
    78  		device.net.RLock()
    79  		defer device.net.RUnlock()
    80  
    81  		device.staticIdentity.RLock()
    82  		defer device.staticIdentity.RUnlock()
    83  
    84  		device.peers.RLock()
    85  		defer device.peers.RUnlock()
    86  
    87  		// serialize device related values
    88  
    89  		if !device.staticIdentity.privateKey.IsZero() {
    90  			keyf("private_key", (*[32]byte)(&device.staticIdentity.privateKey))
    91  		}
    92  
    93  		if device.net.port != 0 {
    94  			sendf("listen_port=%d", device.net.port)
    95  		}
    96  
    97  		if device.net.fwmark != 0 {
    98  			sendf("fwmark=%d", device.net.fwmark)
    99  		}
   100  
   101  		for _, peer := range device.peers.keyMap {
   102  			// Serialize peer state.
   103  			// Do the work in an anonymous function so that we can use defer.
   104  			func() {
   105  				peer.RLock()
   106  				defer peer.RUnlock()
   107  
   108  				keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic))
   109  				keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey))
   110  				sendf("protocol_version=1")
   111  				if peer.endpoint != nil {
   112  					sendf("endpoint=%s", peer.endpoint.DstToString())
   113  				}
   114  
   115  				nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano)
   116  				secs := nano / time.Second.Nanoseconds()
   117  				nano %= time.Second.Nanoseconds()
   118  
   119  				sendf("last_handshake_time_sec=%d", secs)
   120  				sendf("last_handshake_time_nsec=%d", nano)
   121  				sendf("tx_bytes=%d", atomic.LoadUint64(&peer.stats.txBytes))
   122  				sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes))
   123  				sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval))
   124  
   125  				device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
   126  					sendf("allowed_ip=%s", prefix.String())
   127  					return true
   128  				})
   129  			}()
   130  		}
   131  	}()
   132  
   133  	// send lines (does not require resource locks)
   134  	if _, err := w.Write(buf.Bytes()); err != nil {
   135  		return ipcErrorf(ipc.IpcErrorIO, "failed to write output: %w", err)
   136  	}
   137  
   138  	return nil
   139  }
   140  
   141  // IpcSetOperation implements the WireGuard configuration protocol "set" operation.
   142  // See https://www.wireguard.com/xplatform/#configuration-protocol for details.
   143  func (device *Device) IpcSetOperation(r io.Reader) (err error) {
   144  	device.ipcMutex.Lock()
   145  	defer device.ipcMutex.Unlock()
   146  
   147  	defer func() {
   148  		if err != nil {
   149  			device.log.Errorf("%v", err)
   150  		}
   151  	}()
   152  
   153  	peer := new(ipcSetPeer)
   154  	deviceConfig := true
   155  
   156  	scanner := bufio.NewScanner(r)
   157  	for scanner.Scan() {
   158  		line := scanner.Text()
   159  		if line == "" {
   160  			// Blank line means terminate operation.
   161  			peer.handlePostConfig()
   162  			return nil
   163  		}
   164  		key, value, ok := strings.Cut(line, "=")
   165  		if !ok {
   166  			return ipcErrorf(ipc.IpcErrorProtocol, "failed to parse line %q", line)
   167  		}
   168  
   169  		if key == "public_key" {
   170  			if deviceConfig {
   171  				deviceConfig = false
   172  			}
   173  			peer.handlePostConfig()
   174  			// Load/create the peer we are now configuring.
   175  			err := device.handlePublicKeyLine(peer, value)
   176  			if err != nil {
   177  				return err
   178  			}
   179  			continue
   180  		}
   181  
   182  		var err error
   183  		if deviceConfig {
   184  			err = device.handleDeviceLine(key, value)
   185  		} else {
   186  			err = device.handlePeerLine(peer, key, value)
   187  		}
   188  		if err != nil {
   189  			return err
   190  		}
   191  	}
   192  	peer.handlePostConfig()
   193  
   194  	if err := scanner.Err(); err != nil {
   195  		return ipcErrorf(ipc.IpcErrorIO, "failed to read input: %w", err)
   196  	}
   197  	return nil
   198  }
   199  
   200  func (device *Device) handleDeviceLine(key, value string) error {
   201  	switch key {
   202  	case "private_key":
   203  		var sk NoisePrivateKey
   204  		err := sk.FromMaybeZeroHex(value)
   205  		if err != nil {
   206  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set private_key: %w", err)
   207  		}
   208  		device.log.Verbosef("UAPI: Updating private key")
   209  		device.SetPrivateKey(sk)
   210  
   211  	case "listen_port":
   212  		port, err := strconv.ParseUint(value, 10, 16)
   213  		if err != nil {
   214  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse listen_port: %w", err)
   215  		}
   216  
   217  		// update port and rebind
   218  		device.log.Verbosef("UAPI: Updating listen port")
   219  
   220  		device.net.Lock()
   221  		device.net.port = uint16(port)
   222  		device.net.Unlock()
   223  
   224  		if err := device.BindUpdate(); err != nil {
   225  			return ipcErrorf(ipc.IpcErrorPortInUse, "failed to set listen_port: %w", err)
   226  		}
   227  
   228  	case "fwmark":
   229  		mark, err := strconv.ParseUint(value, 10, 32)
   230  		if err != nil {
   231  			return ipcErrorf(ipc.IpcErrorInvalid, "invalid fwmark: %w", err)
   232  		}
   233  
   234  		device.log.Verbosef("UAPI: Updating fwmark")
   235  		if err := device.BindSetMark(uint32(mark)); err != nil {
   236  			return ipcErrorf(ipc.IpcErrorPortInUse, "failed to update fwmark: %w", err)
   237  		}
   238  
   239  	case "replace_peers":
   240  		if value != "true" {
   241  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value)
   242  		}
   243  		device.log.Verbosef("UAPI: Removing all peers")
   244  		device.RemoveAllPeers()
   245  
   246  	default:
   247  		return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)
   248  	}
   249  
   250  	return nil
   251  }
   252  
   253  // An ipcSetPeer is the current state of an IPC set operation on a peer.
   254  type ipcSetPeer struct {
   255  	*Peer        // Peer is the current peer being operated on
   256  	dummy   bool // dummy reports whether this peer is a temporary, placeholder peer
   257  	created bool // new reports whether this is a newly created peer
   258  	pkaOn   bool // pkaOn reports whether the peer had the persistent keepalive turn on
   259  }
   260  
   261  func (peer *ipcSetPeer) handlePostConfig() {
   262  	if peer.Peer == nil || peer.dummy {
   263  		return
   264  	}
   265  	if peer.created {
   266  		peer.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint != nil
   267  	}
   268  	if peer.device.isUp() {
   269  		peer.Start()
   270  		if peer.pkaOn {
   271  			peer.SendKeepalive()
   272  		}
   273  		peer.SendStagedPackets()
   274  	}
   275  }
   276  
   277  func (device *Device) handlePublicKeyLine(peer *ipcSetPeer, value string) error {
   278  	// Load/create the peer we are configuring.
   279  	var publicKey NoisePublicKey
   280  	err := publicKey.FromHex(value)
   281  	if err != nil {
   282  		return ipcErrorf(ipc.IpcErrorInvalid, "failed to get peer by public key: %w", err)
   283  	}
   284  
   285  	// Ignore peer with the same public key as this device.
   286  	device.staticIdentity.RLock()
   287  	peer.dummy = device.staticIdentity.publicKey.Equals(publicKey)
   288  	device.staticIdentity.RUnlock()
   289  
   290  	if peer.dummy {
   291  		peer.Peer = &Peer{}
   292  	} else {
   293  		peer.Peer = device.LookupPeer(publicKey)
   294  	}
   295  
   296  	peer.created = peer.Peer == nil
   297  	if peer.created {
   298  		peer.Peer, err = device.NewPeer(publicKey)
   299  		if err != nil {
   300  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to create new peer: %w", err)
   301  		}
   302  		device.log.Verbosef("%v - UAPI: Created", peer.Peer)
   303  	}
   304  	return nil
   305  }
   306  
   307  func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error {
   308  	switch key {
   309  	case "update_only":
   310  		// allow disabling of creation
   311  		if value != "true" {
   312  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value)
   313  		}
   314  		if peer.created && !peer.dummy {
   315  			device.RemovePeer(peer.handshake.remoteStatic)
   316  			peer.Peer = &Peer{}
   317  			peer.dummy = true
   318  		}
   319  
   320  	case "remove":
   321  		// remove currently selected peer from device
   322  		if value != "true" {
   323  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set remove, invalid value: %v", value)
   324  		}
   325  		if !peer.dummy {
   326  			device.log.Verbosef("%v - UAPI: Removing", peer.Peer)
   327  			device.RemovePeer(peer.handshake.remoteStatic)
   328  		}
   329  		peer.Peer = &Peer{}
   330  		peer.dummy = true
   331  
   332  	case "preshared_key":
   333  		device.log.Verbosef("%v - UAPI: Updating preshared key", peer.Peer)
   334  
   335  		peer.handshake.mutex.Lock()
   336  		err := peer.handshake.presharedKey.FromHex(value)
   337  		peer.handshake.mutex.Unlock()
   338  
   339  		if err != nil {
   340  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set preshared key: %w", err)
   341  		}
   342  
   343  	case "endpoint":
   344  		device.log.Verbosef("%v - UAPI: Updating endpoint", peer.Peer)
   345  		endpoint, err := device.net.bind.ParseEndpoint(value)
   346  		if err != nil {
   347  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err)
   348  		}
   349  		peer.Lock()
   350  		defer peer.Unlock()
   351  		peer.endpoint = endpoint
   352  
   353  	case "persistent_keepalive_interval":
   354  		device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer)
   355  
   356  		secs, err := strconv.ParseUint(value, 10, 16)
   357  		if err != nil {
   358  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err)
   359  		}
   360  
   361  		old := atomic.SwapUint32(&peer.persistentKeepaliveInterval, uint32(secs))
   362  
   363  		// Send immediate keepalive if we're turning it on and before it wasn't on.
   364  		peer.pkaOn = old == 0 && secs != 0
   365  
   366  	case "replace_allowed_ips":
   367  		device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer)
   368  		if value != "true" {
   369  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value)
   370  		}
   371  		if peer.dummy {
   372  			return nil
   373  		}
   374  		device.allowedips.RemoveByPeer(peer.Peer)
   375  
   376  	case "allowed_ip":
   377  		device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
   378  		prefix, err := netip.ParsePrefix(value)
   379  		if err != nil {
   380  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
   381  		}
   382  		if peer.dummy {
   383  			return nil
   384  		}
   385  		device.allowedips.Insert(prefix, peer.Peer)
   386  
   387  	case "protocol_version":
   388  		if value != "1" {
   389  			return ipcErrorf(ipc.IpcErrorInvalid, "invalid protocol version: %v", value)
   390  		}
   391  
   392  	default:
   393  		return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI peer key: %v", key)
   394  	}
   395  
   396  	return nil
   397  }
   398  
   399  func (device *Device) IpcGet() (string, error) {
   400  	buf := new(strings.Builder)
   401  	if err := device.IpcGetOperation(buf); err != nil {
   402  		return "", err
   403  	}
   404  	return buf.String(), nil
   405  }
   406  
   407  func (device *Device) IpcSet(uapiConf string) error {
   408  	return device.IpcSetOperation(strings.NewReader(uapiConf))
   409  }
   410  
   411  func (device *Device) IpcHandle(socket net.Conn) {
   412  	defer socket.Close()
   413  
   414  	buffered := func(s io.ReadWriter) *bufio.ReadWriter {
   415  		reader := bufio.NewReader(s)
   416  		writer := bufio.NewWriter(s)
   417  		return bufio.NewReadWriter(reader, writer)
   418  	}(socket)
   419  
   420  	for {
   421  		op, err := buffered.ReadString('\n')
   422  		if err != nil {
   423  			return
   424  		}
   425  
   426  		// handle operation
   427  		switch op {
   428  		case "set=1\n":
   429  			err = device.IpcSetOperation(buffered.Reader)
   430  		case "get=1\n":
   431  			var nextByte byte
   432  			nextByte, err = buffered.ReadByte()
   433  			if err != nil {
   434  				return
   435  			}
   436  			if nextByte != '\n' {
   437  				err = ipcErrorf(ipc.IpcErrorInvalid, "trailing character in UAPI get: %q", nextByte)
   438  				break
   439  			}
   440  			err = device.IpcGetOperation(buffered.Writer)
   441  		default:
   442  			device.log.Errorf("invalid UAPI operation: %v", op)
   443  			return
   444  		}
   445  
   446  		// write status
   447  		var status *IPCError
   448  		if err != nil && !errors.As(err, &status) {
   449  			// shouldn't happen
   450  			status = ipcErrorf(ipc.IpcErrorUnknown, "other UAPI error: %w", err)
   451  		}
   452  		if status != nil {
   453  			device.log.Errorf("%v", status)
   454  			fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode())
   455  		} else {
   456  			fmt.Fprintf(buffered, "errno=0\n\n")
   457  		}
   458  		buffered.Flush()
   459  	}
   460  }