gitee.com/aurawing/surguard-go@v0.3.1-0.20240409071558-96509a61ecf3/device/uapi.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"bufio"
    10  	"bytes"
    11  	"errors"
    12  	"fmt"
    13  	"io"
    14  	"net"
    15  	"net/netip"
    16  	"strconv"
    17  	"strings"
    18  	"sync"
    19  	"time"
    20  
    21  	"gitee.com/aurawing/surguard-go/ipc"
    22  )
    23  
    24  type IPCError struct {
    25  	code int64 // error code
    26  	err  error // underlying/wrapped error
    27  }
    28  
    29  func (s IPCError) Error() string {
    30  	return fmt.Sprintf("IPC error %d: %v", s.code, s.err)
    31  }
    32  
    33  func (s IPCError) Unwrap() error {
    34  	return s.err
    35  }
    36  
    37  func (s IPCError) ErrorCode() int64 {
    38  	return s.code
    39  }
    40  
    41  func ipcErrorf(code int64, msg string, args ...any) *IPCError {
    42  	return &IPCError{code: code, err: fmt.Errorf(msg, args...)}
    43  }
    44  
    45  var byteBufferPool = &sync.Pool{
    46  	New: func() any { return new(bytes.Buffer) },
    47  }
    48  
    49  // IpcGetOperation implements the WireGuard configuration protocol "get" operation.
    50  // See https://www.wireguard.com/xplatform/#configuration-protocol for details.
    51  func (device *Device) IpcGetOperation(w io.Writer) error {
    52  	device.ipcMutex.RLock()
    53  	defer device.ipcMutex.RUnlock()
    54  
    55  	buf := byteBufferPool.Get().(*bytes.Buffer)
    56  	buf.Reset()
    57  	defer byteBufferPool.Put(buf)
    58  	sendf := func(format string, args ...any) {
    59  		fmt.Fprintf(buf, format, args...)
    60  		buf.WriteByte('\n')
    61  	}
    62  	keyf := func(prefix string, key *[32]byte) {
    63  		buf.Grow(len(key)*2 + 2 + len(prefix))
    64  		buf.WriteString(prefix)
    65  		buf.WriteByte('=')
    66  		const hex = "0123456789abcdef"
    67  		for i := 0; i < len(key); i++ {
    68  			buf.WriteByte(hex[key[i]>>4])
    69  			buf.WriteByte(hex[key[i]&0xf])
    70  		}
    71  		buf.WriteByte('\n')
    72  	}
    73  	keyp := func(prefix string, key *[33]byte) {
    74  		buf.Grow(len(key)*2 + 2 + len(prefix))
    75  		buf.WriteString(prefix)
    76  		buf.WriteByte('=')
    77  		const hex = "0123456789abcdef"
    78  		for i := 0; i < len(key); i++ {
    79  			buf.WriteByte(hex[key[i]>>4])
    80  			buf.WriteByte(hex[key[i]&0xf])
    81  		}
    82  		buf.WriteByte('\n')
    83  	}
    84  
    85  	func() {
    86  		// lock required resources
    87  
    88  		device.net.RLock()
    89  		defer device.net.RUnlock()
    90  
    91  		device.staticIdentity.RLock()
    92  		defer device.staticIdentity.RUnlock()
    93  
    94  		device.peers.RLock()
    95  		defer device.peers.RUnlock()
    96  
    97  		// serialize device related values
    98  
    99  		if !device.staticIdentity.privateKey.IsZero() {
   100  			keyf("private_key", (*[32]byte)(&device.staticIdentity.privateKey))
   101  		}
   102  
   103  		if device.net.port != 0 {
   104  			sendf("listen_port=%d", device.net.port)
   105  		}
   106  
   107  		if device.net.fwmark != 0 {
   108  			sendf("fwmark=%d", device.net.fwmark)
   109  		}
   110  
   111  		for _, peer := range device.peers.keyMap {
   112  			// Serialize peer state.
   113  			peer.handshake.mutex.RLock()
   114  			keyp("public_key", (*[33]byte)(&peer.handshake.remoteStatic))
   115  			keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey))
   116  			peer.handshake.mutex.RUnlock()
   117  			sendf("protocol_version=1")
   118  			peer.endpoint.Lock()
   119  			if peer.endpoint.val != nil {
   120  				sendf("endpoint=%s", peer.endpoint.val.DstToString())
   121  			}
   122  			peer.endpoint.Unlock()
   123  
   124  			nano := peer.lastHandshakeNano.Load()
   125  			secs := nano / time.Second.Nanoseconds()
   126  			nano %= time.Second.Nanoseconds()
   127  
   128  			sendf("last_handshake_time_sec=%d", secs)
   129  			sendf("last_handshake_time_nsec=%d", nano)
   130  			sendf("tx_bytes=%d", peer.txBytes.Load())
   131  			sendf("rx_bytes=%d", peer.rxBytes.Load())
   132  			sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load())
   133  
   134  			device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
   135  				sendf("allowed_ip=%s", prefix.String())
   136  				return true
   137  			})
   138  		}
   139  	}()
   140  
   141  	// send lines (does not require resource locks)
   142  	if _, err := w.Write(buf.Bytes()); err != nil {
   143  		return ipcErrorf(ipc.IpcErrorIO, "failed to write output: %w", err)
   144  	}
   145  
   146  	return nil
   147  }
   148  
   149  // IpcSetOperation implements the WireGuard configuration protocol "set" operation.
   150  // See https://www.wireguard.com/xplatform/#configuration-protocol for details.
   151  func (device *Device) IpcSetOperation(r io.Reader) (err error) {
   152  	device.ipcMutex.Lock()
   153  	defer device.ipcMutex.Unlock()
   154  
   155  	defer func() {
   156  		if err != nil {
   157  			device.log.Errorf("%v", err)
   158  		}
   159  	}()
   160  
   161  	peer := new(ipcSetPeer)
   162  	deviceConfig := true
   163  
   164  	scanner := bufio.NewScanner(r)
   165  	for scanner.Scan() {
   166  		line := scanner.Text()
   167  		if line == "" {
   168  			// Blank line means terminate operation.
   169  			peer.handlePostConfig()
   170  			return nil
   171  		}
   172  		key, value, ok := strings.Cut(line, "=")
   173  		if !ok {
   174  			return ipcErrorf(ipc.IpcErrorProtocol, "failed to parse line %q", line)
   175  		}
   176  
   177  		if key == "public_key" {
   178  			if deviceConfig {
   179  				deviceConfig = false
   180  			}
   181  			peer.handlePostConfig()
   182  			// Load/create the peer we are now configuring.
   183  			err := device.handlePublicKeyLine(peer, value)
   184  			if err != nil {
   185  				return err
   186  			}
   187  			continue
   188  		}
   189  
   190  		var err error
   191  		if deviceConfig {
   192  			err = device.handleDeviceLine(key, value)
   193  		} else {
   194  			err = device.handlePeerLine(peer, key, value)
   195  		}
   196  		if err != nil {
   197  			return err
   198  		}
   199  	}
   200  	peer.handlePostConfig()
   201  
   202  	if err := scanner.Err(); err != nil {
   203  		return ipcErrorf(ipc.IpcErrorIO, "failed to read input: %w", err)
   204  	}
   205  	return nil
   206  }
   207  
   208  func (device *Device) handleDeviceLine(key, value string) error {
   209  	switch key {
   210  	case "private_key":
   211  		var sk NoisePrivateKey
   212  		err := sk.FromMaybeZeroHex(value)
   213  		if err != nil {
   214  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set private_key: %w", err)
   215  		}
   216  		device.log.Verbosef("UAPI: Updating private key")
   217  		device.SetPrivateKey(sk)
   218  
   219  	case "listen_port":
   220  		port, err := strconv.ParseUint(value, 10, 16)
   221  		if err != nil {
   222  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse listen_port: %w", err)
   223  		}
   224  
   225  		// update port and rebind
   226  		device.log.Verbosef("UAPI: Updating listen port")
   227  
   228  		device.net.Lock()
   229  		device.net.port = uint16(port)
   230  		device.net.Unlock()
   231  
   232  		if err := device.BindUpdate(); err != nil {
   233  			return ipcErrorf(ipc.IpcErrorPortInUse, "failed to set listen_port: %w", err)
   234  		}
   235  
   236  	case "fwmark":
   237  		mark, err := strconv.ParseUint(value, 10, 32)
   238  		if err != nil {
   239  			return ipcErrorf(ipc.IpcErrorInvalid, "invalid fwmark: %w", err)
   240  		}
   241  
   242  		device.log.Verbosef("UAPI: Updating fwmark")
   243  		if err := device.BindSetMark(uint32(mark)); err != nil {
   244  			return ipcErrorf(ipc.IpcErrorPortInUse, "failed to update fwmark: %w", err)
   245  		}
   246  
   247  	case "replace_peers":
   248  		if value != "true" {
   249  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value)
   250  		}
   251  		device.log.Verbosef("UAPI: Removing all peers")
   252  		device.RemoveAllPeers()
   253  
   254  	default:
   255  		return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)
   256  	}
   257  
   258  	return nil
   259  }
   260  
   261  // An ipcSetPeer is the current state of an IPC set operation on a peer.
   262  type ipcSetPeer struct {
   263  	*Peer        // Peer is the current peer being operated on
   264  	dummy   bool // dummy reports whether this peer is a temporary, placeholder peer
   265  	created bool // new reports whether this is a newly created peer
   266  	pkaOn   bool // pkaOn reports whether the peer had the persistent keepalive turn on
   267  }
   268  
   269  func (peer *ipcSetPeer) handlePostConfig() {
   270  	if peer.Peer == nil || peer.dummy {
   271  		return
   272  	}
   273  	if peer.created {
   274  		peer.endpoint.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint.val != nil
   275  	}
   276  	if peer.device.isUp() {
   277  		peer.Start()
   278  		if peer.pkaOn {
   279  			peer.SendKeepalive()
   280  		}
   281  		peer.SendStagedPackets()
   282  	}
   283  }
   284  
   285  func (device *Device) handlePublicKeyLine(peer *ipcSetPeer, value string) error {
   286  	// Load/create the peer we are configuring.
   287  	var publicKey NoisePublicKey
   288  	err := publicKey.FromHex(value)
   289  	if err != nil {
   290  		return ipcErrorf(ipc.IpcErrorInvalid, "failed to get peer by public key: %w", err)
   291  	}
   292  
   293  	// Ignore peer with the same public key as this device.
   294  	device.staticIdentity.RLock()
   295  	peer.dummy = device.staticIdentity.publicKey.Equals(publicKey)
   296  	device.staticIdentity.RUnlock()
   297  
   298  	if peer.dummy {
   299  		peer.Peer = &Peer{}
   300  	} else {
   301  		peer.Peer = device.LookupPeer(publicKey)
   302  	}
   303  
   304  	peer.created = peer.Peer == nil
   305  	if peer.created {
   306  		peer.Peer, err = device.NewPeer(publicKey)
   307  		if err != nil {
   308  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to create new peer: %w", err)
   309  		}
   310  		device.log.Verbosef("%v - UAPI: Created", peer.Peer)
   311  	}
   312  	return nil
   313  }
   314  
   315  func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error {
   316  	switch key {
   317  	case "update_only":
   318  		// allow disabling of creation
   319  		if value != "true" {
   320  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value)
   321  		}
   322  		if peer.created && !peer.dummy {
   323  			device.RemovePeer(peer.handshake.remoteStatic)
   324  			peer.Peer = &Peer{}
   325  			peer.dummy = true
   326  		}
   327  
   328  	case "remove":
   329  		// remove currently selected peer from device
   330  		if value != "true" {
   331  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set remove, invalid value: %v", value)
   332  		}
   333  		if !peer.dummy {
   334  			device.log.Verbosef("%v - UAPI: Removing", peer.Peer)
   335  			device.RemovePeer(peer.handshake.remoteStatic)
   336  		}
   337  		peer.Peer = &Peer{}
   338  		peer.dummy = true
   339  
   340  	case "preshared_key":
   341  		device.log.Verbosef("%v - UAPI: Updating preshared key", peer.Peer)
   342  
   343  		peer.handshake.mutex.Lock()
   344  		err := peer.handshake.presharedKey.FromHex(value)
   345  		peer.handshake.mutex.Unlock()
   346  
   347  		if err != nil {
   348  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set preshared key: %w", err)
   349  		}
   350  
   351  	case "endpoint":
   352  		device.log.Verbosef("%v - UAPI: Updating endpoint", peer.Peer)
   353  		endpoint, err := device.net.bind.ParseEndpoint(value)
   354  		if err != nil {
   355  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err)
   356  		}
   357  		peer.endpoint.Lock()
   358  		defer peer.endpoint.Unlock()
   359  		peer.endpoint.val = endpoint
   360  
   361  	case "persistent_keepalive_interval":
   362  		device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer)
   363  
   364  		secs, err := strconv.ParseUint(value, 10, 16)
   365  		if err != nil {
   366  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err)
   367  		}
   368  
   369  		old := peer.persistentKeepaliveInterval.Swap(uint32(secs))
   370  
   371  		// Send immediate keepalive if we're turning it on and before it wasn't on.
   372  		peer.pkaOn = old == 0 && secs != 0
   373  
   374  	case "replace_allowed_ips":
   375  		device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer)
   376  		if value != "true" {
   377  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value)
   378  		}
   379  		if peer.dummy {
   380  			return nil
   381  		}
   382  		device.allowedips.RemoveByPeer(peer.Peer)
   383  
   384  	case "allowed_ip":
   385  		device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
   386  		prefix, err := netip.ParsePrefix(value)
   387  		if err != nil {
   388  			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
   389  		}
   390  		if peer.dummy {
   391  			return nil
   392  		}
   393  		device.allowedips.Insert(prefix, peer.Peer)
   394  
   395  	case "protocol_version":
   396  		if value != "1" {
   397  			return ipcErrorf(ipc.IpcErrorInvalid, "invalid protocol version: %v", value)
   398  		}
   399  
   400  	default:
   401  		return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI peer key: %v", key)
   402  	}
   403  
   404  	return nil
   405  }
   406  
   407  func (device *Device) IpcGet() (string, error) {
   408  	buf := new(strings.Builder)
   409  	if err := device.IpcGetOperation(buf); err != nil {
   410  		return "", err
   411  	}
   412  	return buf.String(), nil
   413  }
   414  
   415  func (device *Device) IpcSet(uapiConf string) error {
   416  	return device.IpcSetOperation(strings.NewReader(uapiConf))
   417  }
   418  
   419  func (device *Device) IpcHandle(socket net.Conn) {
   420  	defer socket.Close()
   421  
   422  	buffered := func(s io.ReadWriter) *bufio.ReadWriter {
   423  		reader := bufio.NewReader(s)
   424  		writer := bufio.NewWriter(s)
   425  		return bufio.NewReadWriter(reader, writer)
   426  	}(socket)
   427  
   428  	for {
   429  		op, err := buffered.ReadString('\n')
   430  		if err != nil {
   431  			return
   432  		}
   433  
   434  		// handle operation
   435  		switch op {
   436  		case "set=1\n":
   437  			err = device.IpcSetOperation(buffered.Reader)
   438  		case "get=1\n":
   439  			var nextByte byte
   440  			nextByte, err = buffered.ReadByte()
   441  			if err != nil {
   442  				return
   443  			}
   444  			if nextByte != '\n' {
   445  				err = ipcErrorf(ipc.IpcErrorInvalid, "trailing character in UAPI get: %q", nextByte)
   446  				break
   447  			}
   448  			err = device.IpcGetOperation(buffered.Writer)
   449  		default:
   450  			device.log.Errorf("invalid UAPI operation: %v", op)
   451  			return
   452  		}
   453  
   454  		// write status
   455  		var status *IPCError
   456  		if err != nil && !errors.As(err, &status) {
   457  			// shouldn't happen
   458  			status = ipcErrorf(ipc.IpcErrorUnknown, "other UAPI error: %w", err)
   459  		}
   460  		if status != nil {
   461  			device.log.Errorf("%v", status)
   462  			fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode())
   463  		} else {
   464  			fmt.Fprintf(buffered, "errno=0\n\n")
   465  		}
   466  		buffered.Flush()
   467  	}
   468  }