github.com/sagernet/wireguard-go@v0.0.0-20231215174105-89dec3b2f3e8/device/uapi.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"bufio"
    10  	"bytes"
    11  	"errors"
    12  	"fmt"
    13  	"io"
    14  	"net"
    15  	"net/netip"
    16  	"strconv"
    17  	"strings"
    18  	"sync"
    19  	"time"
    20  
    21  	"github.com/sagernet/wireguard-go/ipc"
    22  )
    23  
    24  type IPCError struct {
    25  	code int64 // error code
    26  	err  error // underlying/wrapped error
    27  }
    28  
    29  func (s IPCError) Error() string {
    30  	return fmt.Sprintf("IPC error %d: %v", s.code, s.err)
    31  }
    32  
    33  func (s IPCError) Unwrap() error {
    34  	return s.err
    35  }
    36  
    37  func (s IPCError) ErrorCode() int64 {
    38  	return s.code
    39  }
    40  
    41  func ipcErrorf(code int64, msg string, args ...any) *IPCError {
    42  	return &IPCError{code: code, err: fmt.Errorf(msg, args...)}
    43  }
    44  
    45  var byteBufferPool = &sync.Pool{
    46  	New: func() any { return new(bytes.Buffer) },
    47  }
    48  
    49  // IpcGetOperation implements the WireGuard configuration protocol "get" operation.
    50  // See https://www.wireguard.com/xplatform/#configuration-protocol for details.
    51  func (device *Device) IpcGetOperation(w io.Writer) error {
    52  	device.ipcMutex.RLock()
    53  	defer device.ipcMutex.RUnlock()
    54  
    55  	buf := byteBufferPool.Get().(*bytes.Buffer)
    56  	buf.Reset()
    57  	defer byteBufferPool.Put(buf)
    58  	sendf := func(format string, args ...any) {
    59  		fmt.Fprintf(buf, format, args...)
    60  		buf.WriteByte('\n')
    61  	}
    62  	keyf := func(prefix string, key *[32]byte) {
    63  		buf.Grow(len(key)*2 + 2 + len(prefix))
    64  		buf.WriteString(prefix)
    65  		buf.WriteByte('=')
    66  		const hex = "0123456789abcdef"
    67  		for i := 0; i < len(key); i++ {
    68  			buf.WriteByte(hex[key[i]>>4])
    69  			buf.WriteByte(hex[key[i]&0xf])
    70  		}
    71  		buf.WriteByte('\n')
    72  	}
    73  
    74  	func() {
    75  		// lock required resources
    76  
    77  		device.net.RLock()
    78  		defer device.net.RUnlock()
    79  
    80  		device.staticIdentity.RLock()
    81  		defer device.staticIdentity.RUnlock()
    82  
    83  		device.peers.RLock()
    84  		defer device.peers.RUnlock()
    85  
    86  		// serialize device related values
    87  
    88  		if !device.staticIdentity.privateKey.IsZero() {
    89  			keyf("private_key", (*[32]byte)(&device.staticIdentity.privateKey))
    90  		}
    91  
    92  		if device.net.port != 0 {
    93  			sendf("listen_port=%d", device.net.port)
    94  		}
    95  
    96  		if device.net.fwmark != 0 {
    97  			sendf("fwmark=%d", device.net.fwmark)
    98  		}
    99  
   100  		for _, peer := range device.peers.keyMap {
   101  			// Serialize peer state.
   102  			// Do the work in an anonymous function so that we can use defer.
   103  			func() {
   104  				peer.RLock()
   105  				defer peer.RUnlock()
   106  
   107  				keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic))
   108  				keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey))
   109  				sendf("protocol_version=1")
   110  				if peer.endpoint != nil {
   111  					sendf("endpoint=%s", peer.endpoint.DstToString())
   112  				}
   113  
   114  				nano := peer.lastHandshakeNano.Load()
   115  				secs := nano / time.Second.Nanoseconds()
   116  				nano %= time.Second.Nanoseconds()
   117  
   118  				sendf("last_handshake_time_sec=%d", secs)
   119  				sendf("last_handshake_time_nsec=%d", nano)
   120  				sendf("tx_bytes=%d", peer.txBytes.Load())
   121  				sendf("rx_bytes=%d", peer.rxBytes.Load())
   122  				sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load())
   123  
   124  				device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
   125  					sendf("allowed_ip=%s", prefix.String())
   126  					return true
   127  				})
   128  			}()
   129  		}
   130  	}()
   131  
   132  	// send lines (does not require resource locks)
   133  	if _, err := w.Write(buf.Bytes()); err != nil {
   134  		return ipcErrorf(ipc.IpcErrorIO, "failed to write output: %w", err)
   135  	}
   136  
   137  	return nil
   138  }
   139  
   140  // IpcSetOperation implements the WireGuard configuration protocol "set" operation.
   141  // See https://www.wireguard.com/xplatform/#configuration-protocol for details.
   142  func (device *Device) IpcSetOperation(r io.Reader) (err error) {
   143  	device.ipcMutex.Lock()
   144  	defer device.ipcMutex.Unlock()
   145  
   146  	defer func() {
   147  		if err != nil {
   148  			device.log.Errorf("%v", err)
   149  		}
   150  	}()
   151  
   152  	peer := new(ipcSetPeer)
   153  	deviceConfig := true
   154  
   155  	scanner := bufio.NewScanner(r)
   156  	for scanner.Scan() {
   157  		line := scanner.Text()
   158  		if line == "" {
   159  			// Blank line means terminate operation.
   160  			peer.handlePostConfig()
   161  			return nil
   162  		}
   163  		key, value, ok := strings.Cut(line, "=")
   164  		if !ok {
   165  			return ipcErrorf(ipc.IpcErrorProtocol, "failed to parse line %q", line)
   166  		}
   167  
   168  		if key == "public_key" {
   169  			if deviceConfig {
   170  				deviceConfig = false
   171  			}
   172  			peer.handlePostConfig()
   173  			// Load/create the peer we are now configuring.
   174  			err := device.handlePublicKeyLine(peer, value)
   175  			if err != nil {
   176  				return err
   177  			}
   178  			continue
   179  		}
   180  
   181  		var err error
   182  		if deviceConfig {
   183  			err = device.handleDeviceLine(key, value)
   184  		} else {
   185  			err = device.handlePeerLine(peer, key, value)
   186  		}
   187  		if err != nil {
   188  			return err
   189  		}
   190  	}
   191  	peer.handlePostConfig()
   192  
   193  	if err := scanner.Err(); err != nil {
   194  		return ipcErrorf(ipc.IpcErrorIO, "failed to read input: %w", err)
   195  	}
   196  	return nil
   197  }
   198  
   199  func (device *Device) handleDeviceLine(key, value string) error {
   200  	switch key {
   201  	case "private_key":
   202  		var sk NoisePrivateKey
   203  		err := sk.FromMaybeZeroHex(value)
   204  		if err != nil {
   205  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set private_key: %w", err)
   206  		}
   207  		device.log.Verbosef("UAPI: Updating private key")
   208  		device.SetPrivateKey(sk)
   209  
   210  	case "listen_port":
   211  		port, err := strconv.ParseUint(value, 10, 16)
   212  		if err != nil {
   213  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse listen_port: %w", err)
   214  		}
   215  
   216  		// update port and rebind
   217  		device.log.Verbosef("UAPI: Updating listen port")
   218  
   219  		device.net.Lock()
   220  		device.net.port = uint16(port)
   221  		device.net.Unlock()
   222  
   223  		if err := device.BindUpdate(); err != nil {
   224  			return ipcErrorf(ipc.IpcErrorPortInUse, "failed to set listen_port: %w", err)
   225  		}
   226  
   227  	case "fwmark":
   228  		mark, err := strconv.ParseUint(value, 10, 32)
   229  		if err != nil {
   230  			return ipcErrorf(ipc.IpcErrorInvalid, "invalid fwmark: %w", err)
   231  		}
   232  
   233  		device.log.Verbosef("UAPI: Updating fwmark")
   234  		if err := device.BindSetMark(uint32(mark)); err != nil {
   235  			return ipcErrorf(ipc.IpcErrorPortInUse, "failed to update fwmark: %w", err)
   236  		}
   237  
   238  	case "replace_peers":
   239  		if value != "true" {
   240  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value)
   241  		}
   242  		device.log.Verbosef("UAPI: Removing all peers")
   243  		device.RemoveAllPeers()
   244  
   245  	default:
   246  		return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)
   247  	}
   248  
   249  	return nil
   250  }
   251  
   252  // An ipcSetPeer is the current state of an IPC set operation on a peer.
   253  type ipcSetPeer struct {
   254  	*Peer        // Peer is the current peer being operated on
   255  	dummy   bool // dummy reports whether this peer is a temporary, placeholder peer
   256  	created bool // new reports whether this is a newly created peer
   257  	pkaOn   bool // pkaOn reports whether the peer had the persistent keepalive turn on
   258  }
   259  
   260  func (peer *ipcSetPeer) handlePostConfig() {
   261  	if peer.Peer == nil || peer.dummy {
   262  		return
   263  	}
   264  	if peer.created {
   265  		peer.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint != nil
   266  	}
   267  	if peer.device.isUp() {
   268  		peer.Start()
   269  		if peer.pkaOn {
   270  			peer.SendKeepalive()
   271  		}
   272  		peer.SendStagedPackets()
   273  	}
   274  }
   275  
   276  func (device *Device) handlePublicKeyLine(peer *ipcSetPeer, value string) error {
   277  	// Load/create the peer we are configuring.
   278  	var publicKey NoisePublicKey
   279  	err := publicKey.FromHex(value)
   280  	if err != nil {
   281  		return ipcErrorf(ipc.IpcErrorInvalid, "failed to get peer by public key: %w", err)
   282  	}
   283  
   284  	// Ignore peer with the same public key as this device.
   285  	device.staticIdentity.RLock()
   286  	peer.dummy = device.staticIdentity.publicKey.Equals(publicKey)
   287  	device.staticIdentity.RUnlock()
   288  
   289  	if peer.dummy {
   290  		peer.Peer = &Peer{}
   291  	} else {
   292  		peer.Peer = device.LookupPeer(publicKey)
   293  	}
   294  
   295  	peer.created = peer.Peer == nil
   296  	if peer.created {
   297  		peer.Peer, err = device.NewPeer(publicKey)
   298  		if err != nil {
   299  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to create new peer: %w", err)
   300  		}
   301  		device.log.Verbosef("%v - UAPI: Created", peer.Peer)
   302  	}
   303  	return nil
   304  }
   305  
   306  func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error {
   307  	switch key {
   308  	case "update_only":
   309  		// allow disabling of creation
   310  		if value != "true" {
   311  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value)
   312  		}
   313  		if peer.created && !peer.dummy {
   314  			device.RemovePeer(peer.handshake.remoteStatic)
   315  			peer.Peer = &Peer{}
   316  			peer.dummy = true
   317  		}
   318  
   319  	case "remove":
   320  		// remove currently selected peer from device
   321  		if value != "true" {
   322  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set remove, invalid value: %v", value)
   323  		}
   324  		if !peer.dummy {
   325  			device.log.Verbosef("%v - UAPI: Removing", peer.Peer)
   326  			device.RemovePeer(peer.handshake.remoteStatic)
   327  		}
   328  		peer.Peer = &Peer{}
   329  		peer.dummy = true
   330  
   331  	case "preshared_key":
   332  		device.log.Verbosef("%v - UAPI: Updating preshared key", peer.Peer)
   333  
   334  		peer.handshake.mutex.Lock()
   335  		err := peer.handshake.presharedKey.FromHex(value)
   336  		peer.handshake.mutex.Unlock()
   337  
   338  		if err != nil {
   339  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set preshared key: %w", err)
   340  		}
   341  
   342  	case "endpoint":
   343  		device.log.Verbosef("%v - UAPI: Updating endpoint", peer.Peer)
   344  		endpoint, err := device.net.bind.ParseEndpoint(value)
   345  		if err != nil {
   346  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err)
   347  		}
   348  		peer.Lock()
   349  		defer peer.Unlock()
   350  		peer.endpoint = endpoint
   351  
   352  	case "persistent_keepalive_interval":
   353  		device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer)
   354  
   355  		secs, err := strconv.ParseUint(value, 10, 16)
   356  		if err != nil {
   357  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err)
   358  		}
   359  
   360  		old := peer.persistentKeepaliveInterval.Swap(uint32(secs))
   361  
   362  		// Send immediate keepalive if we're turning it on and before it wasn't on.
   363  		peer.pkaOn = old == 0 && secs != 0
   364  
   365  	case "replace_allowed_ips":
   366  		device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer)
   367  		if value != "true" {
   368  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value)
   369  		}
   370  		if peer.dummy {
   371  			return nil
   372  		}
   373  		device.allowedips.RemoveByPeer(peer.Peer)
   374  
   375  	case "allowed_ip":
   376  		device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
   377  		prefix, err := netip.ParsePrefix(value)
   378  		if err != nil {
   379  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
   380  		}
   381  		if peer.dummy {
   382  			return nil
   383  		}
   384  		device.allowedips.Insert(prefix, peer.Peer)
   385  
   386  	case "protocol_version":
   387  		if value != "1" {
   388  			return ipcErrorf(ipc.IpcErrorInvalid, "invalid protocol version: %v", value)
   389  		}
   390  
   391  	default:
   392  		return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI peer key: %v", key)
   393  	}
   394  
   395  	return nil
   396  }
   397  
   398  func (device *Device) IpcGet() (string, error) {
   399  	buf := new(strings.Builder)
   400  	if err := device.IpcGetOperation(buf); err != nil {
   401  		return "", err
   402  	}
   403  	return buf.String(), nil
   404  }
   405  
   406  func (device *Device) IpcSet(uapiConf string) error {
   407  	return device.IpcSetOperation(strings.NewReader(uapiConf))
   408  }
   409  
   410  func (device *Device) IpcHandle(socket net.Conn) {
   411  	defer socket.Close()
   412  
   413  	buffered := func(s io.ReadWriter) *bufio.ReadWriter {
   414  		reader := bufio.NewReader(s)
   415  		writer := bufio.NewWriter(s)
   416  		return bufio.NewReadWriter(reader, writer)
   417  	}(socket)
   418  
   419  	for {
   420  		op, err := buffered.ReadString('\n')
   421  		if err != nil {
   422  			return
   423  		}
   424  
   425  		// handle operation
   426  		switch op {
   427  		case "set=1\n":
   428  			err = device.IpcSetOperation(buffered.Reader)
   429  		case "get=1\n":
   430  			var nextByte byte
   431  			nextByte, err = buffered.ReadByte()
   432  			if err != nil {
   433  				return
   434  			}
   435  			if nextByte != '\n' {
   436  				err = ipcErrorf(ipc.IpcErrorInvalid, "trailing character in UAPI get: %q", nextByte)
   437  				break
   438  			}
   439  			err = device.IpcGetOperation(buffered.Writer)
   440  		default:
   441  			device.log.Errorf("invalid UAPI operation: %v", op)
   442  			return
   443  		}
   444  
   445  		// write status
   446  		var status *IPCError
   447  		if err != nil && !errors.As(err, &status) {
   448  			// shouldn't happen
   449  			status = ipcErrorf(ipc.IpcErrorUnknown, "other UAPI error: %w", err)
   450  		}
   451  		if status != nil {
   452  			device.log.Errorf("%v", status)
   453  			fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode())
   454  		} else {
   455  			fmt.Fprintf(buffered, "errno=0\n\n")
   456  		}
   457  		buffered.Flush()
   458  	}
   459  }