github.com/liloew/wireguard-go@v0.0.0-20220224014633-9cd745e6f114/device/uapi.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"bufio"
    10  	"bytes"
    11  	"errors"
    12  	"fmt"
    13  	"io"
    14  	"net"
    15  	"strconv"
    16  	"strings"
    17  	"sync"
    18  	"sync/atomic"
    19  	"time"
    20  
    21  	"golang.zx2c4.com/go118/netip"
    22  	"github.com/liloew/wireguard-go/ipc"
    23  )
    24  
    25  type IPCError struct {
    26  	code int64 // error code
    27  	err  error // underlying/wrapped error
    28  }
    29  
    30  func (s IPCError) Error() string {
    31  	return fmt.Sprintf("IPC error %d: %v", s.code, s.err)
    32  }
    33  
    34  func (s IPCError) Unwrap() error {
    35  	return s.err
    36  }
    37  
    38  func (s IPCError) ErrorCode() int64 {
    39  	return s.code
    40  }
    41  
    42  func ipcErrorf(code int64, msg string, args ...interface{}) *IPCError {
    43  	return &IPCError{code: code, err: fmt.Errorf(msg, args...)}
    44  }
    45  
    46  var byteBufferPool = &sync.Pool{
    47  	New: func() interface{} { return new(bytes.Buffer) },
    48  }
    49  
    50  // IpcGetOperation implements the WireGuard configuration protocol "get" operation.
    51  // See https://www.wireguard.com/xplatform/#configuration-protocol for details.
    52  func (device *Device) IpcGetOperation(w io.Writer) error {
    53  	device.ipcMutex.RLock()
    54  	defer device.ipcMutex.RUnlock()
    55  
    56  	buf := byteBufferPool.Get().(*bytes.Buffer)
    57  	buf.Reset()
    58  	defer byteBufferPool.Put(buf)
    59  	sendf := func(format string, args ...interface{}) {
    60  		fmt.Fprintf(buf, format, args...)
    61  		buf.WriteByte('\n')
    62  	}
    63  	keyf := func(prefix string, key *[32]byte) {
    64  		buf.Grow(len(key)*2 + 2 + len(prefix))
    65  		buf.WriteString(prefix)
    66  		buf.WriteByte('=')
    67  		const hex = "0123456789abcdef"
    68  		for i := 0; i < len(key); i++ {
    69  			buf.WriteByte(hex[key[i]>>4])
    70  			buf.WriteByte(hex[key[i]&0xf])
    71  		}
    72  		buf.WriteByte('\n')
    73  	}
    74  
    75  	func() {
    76  		// lock required resources
    77  
    78  		device.net.RLock()
    79  		defer device.net.RUnlock()
    80  
    81  		device.staticIdentity.RLock()
    82  		defer device.staticIdentity.RUnlock()
    83  
    84  		device.peers.RLock()
    85  		defer device.peers.RUnlock()
    86  
    87  		// serialize device related values
    88  
    89  		if !device.staticIdentity.privateKey.IsZero() {
    90  			keyf("private_key", (*[32]byte)(&device.staticIdentity.privateKey))
    91  		}
    92  
    93  		if device.net.port != 0 {
    94  			sendf("listen_port=%d", device.net.port)
    95  		}
    96  
    97  		if device.net.fwmark != 0 {
    98  			sendf("fwmark=%d", device.net.fwmark)
    99  		}
   100  
   101  		for _, peer := range device.peers.keyMap {
   102  			// Serialize peer state.
   103  			// Do the work in an anonymous function so that we can use defer.
   104  			func() {
   105  				peer.RLock()
   106  				defer peer.RUnlock()
   107  
   108  				keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic))
   109  				keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey))
   110  				sendf("protocol_version=1")
   111  				if peer.endpoint != nil {
   112  					sendf("endpoint=%s", peer.endpoint.DstToString())
   113  				}
   114  
   115  				nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano)
   116  				secs := nano / time.Second.Nanoseconds()
   117  				nano %= time.Second.Nanoseconds()
   118  
   119  				sendf("last_handshake_time_sec=%d", secs)
   120  				sendf("last_handshake_time_nsec=%d", nano)
   121  				sendf("tx_bytes=%d", atomic.LoadUint64(&peer.stats.txBytes))
   122  				sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes))
   123  				sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval))
   124  
   125  				device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
   126  					sendf("allowed_ip=%s", prefix.String())
   127  					return true
   128  				})
   129  			}()
   130  		}
   131  	}()
   132  
   133  	// send lines (does not require resource locks)
   134  	if _, err := w.Write(buf.Bytes()); err != nil {
   135  		return ipcErrorf(ipc.IpcErrorIO, "failed to write output: %w", err)
   136  	}
   137  
   138  	return nil
   139  }
   140  
   141  // IpcSetOperation implements the WireGuard configuration protocol "set" operation.
   142  // See https://www.wireguard.com/xplatform/#configuration-protocol for details.
   143  func (device *Device) IpcSetOperation(r io.Reader) (err error) {
   144  	device.ipcMutex.Lock()
   145  	defer device.ipcMutex.Unlock()
   146  
   147  	defer func() {
   148  		if err != nil {
   149  			device.log.Errorf("%v", err)
   150  		}
   151  	}()
   152  
   153  	peer := new(ipcSetPeer)
   154  	deviceConfig := true
   155  
   156  	scanner := bufio.NewScanner(r)
   157  	for scanner.Scan() {
   158  		line := scanner.Text()
   159  		if line == "" {
   160  			// Blank line means terminate operation.
   161  			peer.handlePostConfig()
   162  			return nil
   163  		}
   164  		parts := strings.Split(line, "=")
   165  		if len(parts) != 2 {
   166  			return ipcErrorf(ipc.IpcErrorProtocol, "failed to parse line %q, found %d =-separated parts, want 2", line, len(parts))
   167  		}
   168  		key := parts[0]
   169  		value := parts[1]
   170  
   171  		if key == "public_key" {
   172  			if deviceConfig {
   173  				deviceConfig = false
   174  			}
   175  			peer.handlePostConfig()
   176  			// Load/create the peer we are now configuring.
   177  			err := device.handlePublicKeyLine(peer, value)
   178  			if err != nil {
   179  				return err
   180  			}
   181  			continue
   182  		}
   183  
   184  		var err error
   185  		if deviceConfig {
   186  			err = device.handleDeviceLine(key, value)
   187  		} else {
   188  			err = device.handlePeerLine(peer, key, value)
   189  		}
   190  		if err != nil {
   191  			return err
   192  		}
   193  	}
   194  	peer.handlePostConfig()
   195  
   196  	if err := scanner.Err(); err != nil {
   197  		return ipcErrorf(ipc.IpcErrorIO, "failed to read input: %w", err)
   198  	}
   199  	return nil
   200  }
   201  
   202  func (device *Device) handleDeviceLine(key, value string) error {
   203  	switch key {
   204  	case "private_key":
   205  		var sk NoisePrivateKey
   206  		err := sk.FromMaybeZeroHex(value)
   207  		if err != nil {
   208  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set private_key: %w", err)
   209  		}
   210  		device.log.Verbosef("UAPI: Updating private key")
   211  		device.SetPrivateKey(sk)
   212  
   213  	case "listen_port":
   214  		port, err := strconv.ParseUint(value, 10, 16)
   215  		if err != nil {
   216  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse listen_port: %w", err)
   217  		}
   218  
   219  		// update port and rebind
   220  		device.log.Verbosef("UAPI: Updating listen port")
   221  
   222  		device.net.Lock()
   223  		device.net.port = uint16(port)
   224  		device.net.Unlock()
   225  
   226  		if err := device.BindUpdate(); err != nil {
   227  			return ipcErrorf(ipc.IpcErrorPortInUse, "failed to set listen_port: %w", err)
   228  		}
   229  
   230  	case "fwmark":
   231  		mark, err := strconv.ParseUint(value, 10, 32)
   232  		if err != nil {
   233  			return ipcErrorf(ipc.IpcErrorInvalid, "invalid fwmark: %w", err)
   234  		}
   235  
   236  		device.log.Verbosef("UAPI: Updating fwmark")
   237  		if err := device.BindSetMark(uint32(mark)); err != nil {
   238  			return ipcErrorf(ipc.IpcErrorPortInUse, "failed to update fwmark: %w", err)
   239  		}
   240  
   241  	case "replace_peers":
   242  		if value != "true" {
   243  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value)
   244  		}
   245  		device.log.Verbosef("UAPI: Removing all peers")
   246  		device.RemoveAllPeers()
   247  
   248  	default:
   249  		return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)
   250  	}
   251  
   252  	return nil
   253  }
   254  
   255  // An ipcSetPeer is the current state of an IPC set operation on a peer.
   256  type ipcSetPeer struct {
   257  	*Peer        // Peer is the current peer being operated on
   258  	dummy   bool // dummy reports whether this peer is a temporary, placeholder peer
   259  	created bool // new reports whether this is a newly created peer
   260  	pkaOn   bool // pkaOn reports whether the peer had the persistent keepalive turn on
   261  }
   262  
   263  func (peer *ipcSetPeer) handlePostConfig() {
   264  	if peer.Peer == nil || peer.dummy {
   265  		return
   266  	}
   267  	if peer.created {
   268  		peer.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint != nil
   269  	}
   270  	if peer.device.isUp() {
   271  		peer.Start()
   272  		if peer.pkaOn {
   273  			peer.SendKeepalive()
   274  		}
   275  		peer.SendStagedPackets()
   276  	}
   277  }
   278  
   279  func (device *Device) handlePublicKeyLine(peer *ipcSetPeer, value string) error {
   280  	// Load/create the peer we are configuring.
   281  	var publicKey NoisePublicKey
   282  	err := publicKey.FromHex(value)
   283  	if err != nil {
   284  		return ipcErrorf(ipc.IpcErrorInvalid, "failed to get peer by public key: %w", err)
   285  	}
   286  
   287  	// Ignore peer with the same public key as this device.
   288  	device.staticIdentity.RLock()
   289  	peer.dummy = device.staticIdentity.publicKey.Equals(publicKey)
   290  	device.staticIdentity.RUnlock()
   291  
   292  	if peer.dummy {
   293  		peer.Peer = &Peer{}
   294  	} else {
   295  		peer.Peer = device.LookupPeer(publicKey)
   296  	}
   297  
   298  	peer.created = peer.Peer == nil
   299  	if peer.created {
   300  		peer.Peer, err = device.NewPeer(publicKey)
   301  		if err != nil {
   302  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to create new peer: %w", err)
   303  		}
   304  		device.log.Verbosef("%v - UAPI: Created", peer.Peer)
   305  	}
   306  	return nil
   307  }
   308  
   309  func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error {
   310  	switch key {
   311  	case "update_only":
   312  		// allow disabling of creation
   313  		if value != "true" {
   314  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value)
   315  		}
   316  		if peer.created && !peer.dummy {
   317  			device.RemovePeer(peer.handshake.remoteStatic)
   318  			peer.Peer = &Peer{}
   319  			peer.dummy = true
   320  		}
   321  
   322  	case "remove":
   323  		// remove currently selected peer from device
   324  		if value != "true" {
   325  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set remove, invalid value: %v", value)
   326  		}
   327  		if !peer.dummy {
   328  			device.log.Verbosef("%v - UAPI: Removing", peer.Peer)
   329  			device.RemovePeer(peer.handshake.remoteStatic)
   330  		}
   331  		peer.Peer = &Peer{}
   332  		peer.dummy = true
   333  
   334  	case "preshared_key":
   335  		device.log.Verbosef("%v - UAPI: Updating preshared key", peer.Peer)
   336  
   337  		peer.handshake.mutex.Lock()
   338  		err := peer.handshake.presharedKey.FromHex(value)
   339  		peer.handshake.mutex.Unlock()
   340  
   341  		if err != nil {
   342  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set preshared key: %w", err)
   343  		}
   344  
   345  	case "endpoint":
   346  		device.log.Verbosef("%v - UAPI: Updating endpoint", peer.Peer)
   347  		endpoint, err := device.net.bind.ParseEndpoint(value)
   348  		if err != nil {
   349  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err)
   350  		}
   351  		peer.Lock()
   352  		defer peer.Unlock()
   353  		peer.endpoint = endpoint
   354  
   355  	case "persistent_keepalive_interval":
   356  		device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer)
   357  
   358  		secs, err := strconv.ParseUint(value, 10, 16)
   359  		if err != nil {
   360  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err)
   361  		}
   362  
   363  		old := atomic.SwapUint32(&peer.persistentKeepaliveInterval, uint32(secs))
   364  
   365  		// Send immediate keepalive if we're turning it on and before it wasn't on.
   366  		peer.pkaOn = old == 0 && secs != 0
   367  
   368  	case "replace_allowed_ips":
   369  		device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer)
   370  		if value != "true" {
   371  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value)
   372  		}
   373  		if peer.dummy {
   374  			return nil
   375  		}
   376  		device.allowedips.RemoveByPeer(peer.Peer)
   377  
   378  	case "allowed_ip":
   379  		device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
   380  		prefix, err := netip.ParsePrefix(value)
   381  		if err != nil {
   382  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
   383  		}
   384  		if peer.dummy {
   385  			return nil
   386  		}
   387  		device.allowedips.Insert(prefix, peer.Peer)
   388  
   389  	case "protocol_version":
   390  		if value != "1" {
   391  			return ipcErrorf(ipc.IpcErrorInvalid, "invalid protocol version: %v", value)
   392  		}
   393  
   394  	default:
   395  		return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI peer key: %v", key)
   396  	}
   397  
   398  	return nil
   399  }
   400  
   401  func (device *Device) IpcGet() (string, error) {
   402  	buf := new(strings.Builder)
   403  	if err := device.IpcGetOperation(buf); err != nil {
   404  		return "", err
   405  	}
   406  	return buf.String(), nil
   407  }
   408  
   409  func (device *Device) IpcSet(uapiConf string) error {
   410  	return device.IpcSetOperation(strings.NewReader(uapiConf))
   411  }
   412  
   413  func (device *Device) IpcHandle(socket net.Conn) {
   414  	defer socket.Close()
   415  
   416  	buffered := func(s io.ReadWriter) *bufio.ReadWriter {
   417  		reader := bufio.NewReader(s)
   418  		writer := bufio.NewWriter(s)
   419  		return bufio.NewReadWriter(reader, writer)
   420  	}(socket)
   421  
   422  	for {
   423  		op, err := buffered.ReadString('\n')
   424  		if err != nil {
   425  			return
   426  		}
   427  
   428  		// handle operation
   429  		switch op {
   430  		case "set=1\n":
   431  			err = device.IpcSetOperation(buffered.Reader)
   432  		case "get=1\n":
   433  			var nextByte byte
   434  			nextByte, err = buffered.ReadByte()
   435  			if err != nil {
   436  				return
   437  			}
   438  			if nextByte != '\n' {
   439  				err = ipcErrorf(ipc.IpcErrorInvalid, "trailing character in UAPI get: %q", nextByte)
   440  				break
   441  			}
   442  			err = device.IpcGetOperation(buffered.Writer)
   443  		default:
   444  			device.log.Errorf("invalid UAPI operation: %v", op)
   445  			return
   446  		}
   447  
   448  		// write status
   449  		var status *IPCError
   450  		if err != nil && !errors.As(err, &status) {
   451  			// shouldn't happen
   452  			status = ipcErrorf(ipc.IpcErrorUnknown, "other UAPI error: %w", err)
   453  		}
   454  		if status != nil {
   455  			device.log.Errorf("%v", status)
   456  			fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode())
   457  		} else {
   458  			fmt.Fprintf(buffered, "errno=0\n\n")
   459  		}
   460  		buffered.Flush()
   461  	}
   462  }