github.com/GFW-knocker/wireguard@v1.0.1/device/uapi.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"bufio"
    10  	"bytes"
    11  	"errors"
    12  	"fmt"
    13  	"io"
    14  	"net"
    15  	"net/netip"
    16  	"strconv"
    17  	"strings"
    18  	"sync"
    19  	"time"
    20  
    21  	"github.com/GFW-knocker/wireguard/ipc"
    22  )
    23  
    24  type IPCError struct {
    25  	code int64 // error code
    26  	err  error // underlying/wrapped error
    27  }
    28  
    29  func (s IPCError) Error() string {
    30  	return fmt.Sprintf("IPC error %d: %v", s.code, s.err)
    31  }
    32  
    33  func (s IPCError) Unwrap() error {
    34  	return s.err
    35  }
    36  
    37  func (s IPCError) ErrorCode() int64 {
    38  	return s.code
    39  }
    40  
    41  func ipcErrorf(code int64, msg string, args ...any) *IPCError {
    42  	return &IPCError{code: code, err: fmt.Errorf(msg, args...)}
    43  }
    44  
    45  var byteBufferPool = &sync.Pool{
    46  	New: func() any { return new(bytes.Buffer) },
    47  }
    48  
    49  // IpcGetOperation implements the WireGuard configuration protocol "get" operation.
    50  // See https://www.wireguard.com/xplatform/#configuration-protocol for details.
    51  func (device *Device) IpcGetOperation(w io.Writer) error {
    52  	device.ipcMutex.RLock()
    53  	defer device.ipcMutex.RUnlock()
    54  
    55  	buf := byteBufferPool.Get().(*bytes.Buffer)
    56  	buf.Reset()
    57  	defer byteBufferPool.Put(buf)
    58  	sendf := func(format string, args ...any) {
    59  		fmt.Fprintf(buf, format, args...)
    60  		buf.WriteByte('\n')
    61  	}
    62  	keyf := func(prefix string, key *[32]byte) {
    63  		buf.Grow(len(key)*2 + 2 + len(prefix))
    64  		buf.WriteString(prefix)
    65  		buf.WriteByte('=')
    66  		const hex = "0123456789abcdef"
    67  		for i := 0; i < len(key); i++ {
    68  			buf.WriteByte(hex[key[i]>>4])
    69  			buf.WriteByte(hex[key[i]&0xf])
    70  		}
    71  		buf.WriteByte('\n')
    72  	}
    73  
    74  	func() {
    75  		// lock required resources
    76  
    77  		device.net.RLock()
    78  		defer device.net.RUnlock()
    79  
    80  		device.staticIdentity.RLock()
    81  		defer device.staticIdentity.RUnlock()
    82  
    83  		device.peers.RLock()
    84  		defer device.peers.RUnlock()
    85  
    86  		// serialize device related values
    87  
    88  		if !device.staticIdentity.privateKey.IsZero() {
    89  			keyf("private_key", (*[32]byte)(&device.staticIdentity.privateKey))
    90  		}
    91  
    92  		if device.net.port != 0 {
    93  			sendf("listen_port=%d", device.net.port)
    94  		}
    95  
    96  		if device.net.fwmark != 0 {
    97  			sendf("fwmark=%d", device.net.fwmark)
    98  		}
    99  
   100  		for _, peer := range device.peers.keyMap {
   101  			// Serialize peer state.
   102  			peer.handshake.mutex.RLock()
   103  			keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic))
   104  			keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey))
   105  			peer.handshake.mutex.RUnlock()
   106  			sendf("protocol_version=1")
   107  			peer.endpoint.Lock()
   108  			if peer.endpoint.val != nil {
   109  				sendf("endpoint=%s", peer.endpoint.val.DstToString())
   110  			}
   111  			peer.endpoint.Unlock()
   112  
   113  			nano := peer.lastHandshakeNano.Load()
   114  			secs := nano / time.Second.Nanoseconds()
   115  			nano %= time.Second.Nanoseconds()
   116  
   117  			sendf("last_handshake_time_sec=%d", secs)
   118  			sendf("last_handshake_time_nsec=%d", nano)
   119  			sendf("tx_bytes=%d", peer.txBytes.Load())
   120  			sendf("rx_bytes=%d", peer.rxBytes.Load())
   121  			sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load())
   122  
   123  			device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
   124  				sendf("allowed_ip=%s", prefix.String())
   125  				return true
   126  			})
   127  		}
   128  	}()
   129  
   130  	// send lines (does not require resource locks)
   131  	if _, err := w.Write(buf.Bytes()); err != nil {
   132  		return ipcErrorf(ipc.IpcErrorIO, "failed to write output: %w", err)
   133  	}
   134  
   135  	return nil
   136  }
   137  
   138  // IpcSetOperation implements the WireGuard configuration protocol "set" operation.
   139  // See https://www.wireguard.com/xplatform/#configuration-protocol for details.
   140  func (device *Device) IpcSetOperation(r io.Reader) (err error) {
   141  	device.ipcMutex.Lock()
   142  	defer device.ipcMutex.Unlock()
   143  
   144  	defer func() {
   145  		if err != nil {
   146  			device.log.Errorf("%v", err)
   147  		}
   148  	}()
   149  
   150  	peer := new(ipcSetPeer)
   151  	deviceConfig := true
   152  
   153  	scanner := bufio.NewScanner(r)
   154  	for scanner.Scan() {
   155  		line := scanner.Text()
   156  		if line == "" {
   157  			// Blank line means terminate operation.
   158  			peer.handlePostConfig()
   159  			return nil
   160  		}
   161  		key, value, ok := strings.Cut(line, "=")
   162  		if !ok {
   163  			return ipcErrorf(ipc.IpcErrorProtocol, "failed to parse line %q", line)
   164  		}
   165  
   166  		if key == "public_key" {
   167  			if deviceConfig {
   168  				deviceConfig = false
   169  			}
   170  			peer.handlePostConfig()
   171  			// Load/create the peer we are now configuring.
   172  			err := device.handlePublicKeyLine(peer, value)
   173  			if err != nil {
   174  				return err
   175  			}
   176  			continue
   177  		}
   178  
   179  		var err error
   180  		if deviceConfig {
   181  			err = device.handleDeviceLine(key, value)
   182  		} else {
   183  			err = device.handlePeerLine(peer, key, value)
   184  		}
   185  		if err != nil {
   186  			return err
   187  		}
   188  	}
   189  	peer.handlePostConfig()
   190  
   191  	if err := scanner.Err(); err != nil {
   192  		return ipcErrorf(ipc.IpcErrorIO, "failed to read input: %w", err)
   193  	}
   194  	return nil
   195  }
   196  
   197  func (device *Device) handleDeviceLine(key, value string) error {
   198  	switch key {
   199  	case "private_key":
   200  		var sk NoisePrivateKey
   201  		err := sk.FromMaybeZeroHex(value)
   202  		if err != nil {
   203  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set private_key: %w", err)
   204  		}
   205  		device.log.Verbosef("UAPI: Updating private key")
   206  		device.SetPrivateKey(sk)
   207  
   208  	case "listen_port":
   209  		port, err := strconv.ParseUint(value, 10, 16)
   210  		if err != nil {
   211  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse listen_port: %w", err)
   212  		}
   213  
   214  		// update port and rebind
   215  		device.log.Verbosef("UAPI: Updating listen port")
   216  
   217  		device.net.Lock()
   218  		device.net.port = uint16(port)
   219  		device.net.Unlock()
   220  
   221  		if err := device.BindUpdate(); err != nil {
   222  			return ipcErrorf(ipc.IpcErrorPortInUse, "failed to set listen_port: %w", err)
   223  		}
   224  
   225  	case "fwmark":
   226  		mark, err := strconv.ParseUint(value, 10, 32)
   227  		if err != nil {
   228  			return ipcErrorf(ipc.IpcErrorInvalid, "invalid fwmark: %w", err)
   229  		}
   230  
   231  		device.log.Verbosef("UAPI: Updating fwmark")
   232  		if err := device.BindSetMark(uint32(mark)); err != nil {
   233  			return ipcErrorf(ipc.IpcErrorPortInUse, "failed to update fwmark: %w", err)
   234  		}
   235  
   236  	case "replace_peers":
   237  		if value != "true" {
   238  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value)
   239  		}
   240  		device.log.Verbosef("UAPI: Removing all peers")
   241  		device.RemoveAllPeers()
   242  
   243  	default:
   244  		return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)
   245  	}
   246  
   247  	return nil
   248  }
   249  
// An ipcSetPeer is the current state of an IPC set operation on a peer.
type ipcSetPeer struct {
	*Peer        // Peer is the current peer being operated on
	dummy   bool // dummy reports whether this peer is a temporary, placeholder peer
	created bool // created reports whether this is a newly created peer
	pkaOn   bool // pkaOn reports whether the persistent keepalive was just switched from off to on for the peer
}
   257  
   258  func (peer *ipcSetPeer) handlePostConfig() {
   259  	if peer.Peer == nil || peer.dummy {
   260  		return
   261  	}
   262  	if peer.created {
   263  		peer.endpoint.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint.val != nil
   264  	}
   265  	if peer.device.isUp() {
   266  		peer.Start()
   267  		if peer.pkaOn {
   268  			peer.SendKeepalive()
   269  		}
   270  		peer.SendStagedPackets()
   271  	}
   272  }
   273  
   274  func (device *Device) handlePublicKeyLine(peer *ipcSetPeer, value string) error {
   275  	// Load/create the peer we are configuring.
   276  	var publicKey NoisePublicKey
   277  	err := publicKey.FromHex(value)
   278  	if err != nil {
   279  		return ipcErrorf(ipc.IpcErrorInvalid, "failed to get peer by public key: %w", err)
   280  	}
   281  
   282  	// Ignore peer with the same public key as this device.
   283  	device.staticIdentity.RLock()
   284  	peer.dummy = device.staticIdentity.publicKey.Equals(publicKey)
   285  	device.staticIdentity.RUnlock()
   286  
   287  	if peer.dummy {
   288  		peer.Peer = &Peer{}
   289  	} else {
   290  		peer.Peer = device.LookupPeer(publicKey)
   291  	}
   292  
   293  	peer.created = peer.Peer == nil
   294  	if peer.created {
   295  		peer.Peer, err = device.NewPeer(publicKey)
   296  		if err != nil {
   297  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to create new peer: %w", err)
   298  		}
   299  		device.log.Verbosef("%v - UAPI: Created", peer.Peer)
   300  	}
   301  	return nil
   302  }
   303  
   304  func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error {
   305  	switch key {
   306  	case "update_only":
   307  		// allow disabling of creation
   308  		if value != "true" {
   309  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value)
   310  		}
   311  		if peer.created && !peer.dummy {
   312  			device.RemovePeer(peer.handshake.remoteStatic)
   313  			peer.Peer = &Peer{}
   314  			peer.dummy = true
   315  		}
   316  
   317  	case "remove":
   318  		// remove currently selected peer from device
   319  		if value != "true" {
   320  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set remove, invalid value: %v", value)
   321  		}
   322  		if !peer.dummy {
   323  			device.log.Verbosef("%v - UAPI: Removing", peer.Peer)
   324  			device.RemovePeer(peer.handshake.remoteStatic)
   325  		}
   326  		peer.Peer = &Peer{}
   327  		peer.dummy = true
   328  
   329  	case "preshared_key":
   330  		device.log.Verbosef("%v - UAPI: Updating preshared key", peer.Peer)
   331  
   332  		peer.handshake.mutex.Lock()
   333  		err := peer.handshake.presharedKey.FromHex(value)
   334  		peer.handshake.mutex.Unlock()
   335  
   336  		if err != nil {
   337  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set preshared key: %w", err)
   338  		}
   339  
   340  	case "endpoint":
   341  		device.log.Verbosef("%v - UAPI: Updating endpoint", peer.Peer)
   342  		endpoint, err := device.net.bind.ParseEndpoint(value)
   343  		if err != nil {
   344  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err)
   345  		}
   346  		peer.endpoint.Lock()
   347  		defer peer.endpoint.Unlock()
   348  		peer.endpoint.val = endpoint
   349  
   350  	case "persistent_keepalive_interval":
   351  		device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer)
   352  
   353  		secs, err := strconv.ParseUint(value, 10, 16)
   354  		if err != nil {
   355  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err)
   356  		}
   357  
   358  		old := peer.persistentKeepaliveInterval.Swap(uint32(secs))
   359  
   360  		// Send immediate keepalive if we're turning it on and before it wasn't on.
   361  		peer.pkaOn = old == 0 && secs != 0
   362  
   363  	case "replace_allowed_ips":
   364  		device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer)
   365  		if value != "true" {
   366  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value)
   367  		}
   368  		if peer.dummy {
   369  			return nil
   370  		}
   371  		device.allowedips.RemoveByPeer(peer.Peer)
   372  
   373  	case "allowed_ip":
   374  		device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
   375  		prefix, err := netip.ParsePrefix(value)
   376  		if err != nil {
   377  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
   378  		}
   379  		if peer.dummy {
   380  			return nil
   381  		}
   382  		device.allowedips.Insert(prefix, peer.Peer)
   383  
   384  	case "protocol_version":
   385  		if value != "1" {
   386  			return ipcErrorf(ipc.IpcErrorInvalid, "invalid protocol version: %v", value)
   387  		}
   388  
   389  	default:
   390  		return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI peer key: %v", key)
   391  	}
   392  
   393  	return nil
   394  }
   395  
   396  func (device *Device) IpcGet() (string, error) {
   397  	buf := new(strings.Builder)
   398  	if err := device.IpcGetOperation(buf); err != nil {
   399  		return "", err
   400  	}
   401  	return buf.String(), nil
   402  }
   403  
   404  func (device *Device) IpcSet(uapiConf string) error {
   405  	return device.IpcSetOperation(strings.NewReader(uapiConf))
   406  }
   407  
   408  func (device *Device) IpcHandle(socket net.Conn) {
   409  	defer socket.Close()
   410  
   411  	buffered := func(s io.ReadWriter) *bufio.ReadWriter {
   412  		reader := bufio.NewReader(s)
   413  		writer := bufio.NewWriter(s)
   414  		return bufio.NewReadWriter(reader, writer)
   415  	}(socket)
   416  
   417  	for {
   418  		op, err := buffered.ReadString('\n')
   419  		if err != nil {
   420  			return
   421  		}
   422  
   423  		// handle operation
   424  		switch op {
   425  		case "set=1\n":
   426  			err = device.IpcSetOperation(buffered.Reader)
   427  		case "get=1\n":
   428  			var nextByte byte
   429  			nextByte, err = buffered.ReadByte()
   430  			if err != nil {
   431  				return
   432  			}
   433  			if nextByte != '\n' {
   434  				err = ipcErrorf(ipc.IpcErrorInvalid, "trailing character in UAPI get: %q", nextByte)
   435  				break
   436  			}
   437  			err = device.IpcGetOperation(buffered.Writer)
   438  		default:
   439  			device.log.Errorf("invalid UAPI operation: %v", op)
   440  			return
   441  		}
   442  
   443  		// write status
   444  		var status *IPCError
   445  		if err != nil && !errors.As(err, &status) {
   446  			// shouldn't happen
   447  			status = ipcErrorf(ipc.IpcErrorUnknown, "other UAPI error: %w", err)
   448  		}
   449  		if status != nil {
   450  			device.log.Errorf("%v", status)
   451  			fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode())
   452  		} else {
   453  			fmt.Fprintf(buffered, "errno=0\n\n")
   454  		}
   455  		buffered.Flush()
   456  	}
   457  }