github.com/amnezia-vpn/amnezia-wg@v0.1.8/device/uapi.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package device 7 8 import ( 9 "bufio" 10 "bytes" 11 "errors" 12 "fmt" 13 "io" 14 "net" 15 "net/netip" 16 "strconv" 17 "strings" 18 "sync" 19 "time" 20 21 "github.com/amnezia-vpn/amnezia-wg/ipc" 22 ) 23 24 type IPCError struct { 25 code int64 // error code 26 err error // underlying/wrapped error 27 } 28 29 func (s IPCError) Error() string { 30 return fmt.Sprintf("IPC error %d: %v", s.code, s.err) 31 } 32 33 func (s IPCError) Unwrap() error { 34 return s.err 35 } 36 37 func (s IPCError) ErrorCode() int64 { 38 return s.code 39 } 40 41 func ipcErrorf(code int64, msg string, args ...any) *IPCError { 42 return &IPCError{code: code, err: fmt.Errorf(msg, args...)} 43 } 44 45 var byteBufferPool = &sync.Pool{ 46 New: func() any { return new(bytes.Buffer) }, 47 } 48 49 // IpcGetOperation implements the WireGuard configuration protocol "get" operation. 50 // See https://www.wireguard.com/xplatform/#configuration-protocol for details. 51 func (device *Device) IpcGetOperation(w io.Writer) error { 52 device.ipcMutex.RLock() 53 defer device.ipcMutex.RUnlock() 54 55 buf := byteBufferPool.Get().(*bytes.Buffer) 56 buf.Reset() 57 defer byteBufferPool.Put(buf) 58 sendf := func(format string, args ...any) { 59 fmt.Fprintf(buf, format, args...) 60 buf.WriteByte('\n') 61 } 62 keyf := func(prefix string, key *[32]byte) { 63 buf.Grow(len(key)*2 + 2 + len(prefix)) 64 buf.WriteString(prefix) 65 buf.WriteByte('=') 66 const hex = "0123456789abcdef" 67 for i := 0; i < len(key); i++ { 68 buf.WriteByte(hex[key[i]>>4]) 69 buf.WriteByte(hex[key[i]&0xf]) 70 } 71 buf.WriteByte('\n') 72 } 73 74 func() { 75 // lock required resources 76 77 device.net.RLock() 78 defer device.net.RUnlock() 79 80 device.staticIdentity.RLock() 81 defer device.staticIdentity.RUnlock() 82 83 device.peers.RLock() 84 defer device.peers.RUnlock() 85 86 // serialize device related values 87 88 if !device.staticIdentity.privateKey.IsZero() { 89 keyf("private_key", (*[32]byte)(&device.staticIdentity.privateKey)) 90 } 91 92 if device.net.port != 0 { 93 sendf("listen_port=%d", device.net.port) 94 } 95 96 if device.net.fwmark != 0 { 97 sendf("fwmark=%d", device.net.fwmark) 98 } 99 100 if device.isAdvancedSecurityOn() { 101 if device.aSecCfg.junkPacketCount != 0 { 102 sendf("jc=%d", device.aSecCfg.junkPacketCount) 103 } 104 if device.aSecCfg.junkPacketMinSize != 0 { 105 sendf("jmin=%d", device.aSecCfg.junkPacketMinSize) 106 } 107 if device.aSecCfg.junkPacketMaxSize != 0 { 108 sendf("jmax=%d", device.aSecCfg.junkPacketMaxSize) 109 } 110 if device.aSecCfg.initPacketJunkSize != 0 { 111 sendf("s1=%d", device.aSecCfg.initPacketJunkSize) 112 } 113 if device.aSecCfg.responsePacketJunkSize != 0 { 114 sendf("s2=%d", device.aSecCfg.responsePacketJunkSize) 115 } 116 if device.aSecCfg.initPacketMagicHeader != 0 { 117 sendf("h1=%d", device.aSecCfg.initPacketMagicHeader) 118 } 119 if device.aSecCfg.responsePacketMagicHeader != 0 { 120 sendf("h2=%d", device.aSecCfg.responsePacketMagicHeader) 121 } 122 if device.aSecCfg.underloadPacketMagicHeader != 0 { 123 sendf("h3=%d", device.aSecCfg.underloadPacketMagicHeader) 124 } 125 if device.aSecCfg.transportPacketMagicHeader != 0 { 126 sendf("h4=%d", device.aSecCfg.transportPacketMagicHeader) 127 } 128 } 129 130 for _, peer := range device.peers.keyMap { 131 // Serialize peer state. 132 peer.handshake.mutex.RLock() 133 keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic)) 134 keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey)) 135 peer.handshake.mutex.RUnlock() 136 sendf("protocol_version=1") 137 peer.endpoint.Lock() 138 if peer.endpoint.val != nil { 139 sendf("endpoint=%s", peer.endpoint.val.DstToString()) 140 } 141 peer.endpoint.Unlock() 142 143 nano := peer.lastHandshakeNano.Load() 144 secs := nano / time.Second.Nanoseconds() 145 nano %= time.Second.Nanoseconds() 146 147 sendf("last_handshake_time_sec=%d", secs) 148 sendf("last_handshake_time_nsec=%d", nano) 149 sendf("tx_bytes=%d", peer.txBytes.Load()) 150 sendf("rx_bytes=%d", peer.rxBytes.Load()) 151 sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load()) 152 153 device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool { 154 sendf("allowed_ip=%s", prefix.String()) 155 return true 156 }) 157 } 158 }() 159 160 // send lines (does not require resource locks) 161 if _, err := w.Write(buf.Bytes()); err != nil { 162 return ipcErrorf(ipc.IpcErrorIO, "failed to write output: %w", err) 163 } 164 165 return nil 166 } 167 168 // IpcSetOperation implements the WireGuard configuration protocol "set" operation. 169 // See https://www.wireguard.com/xplatform/#configuration-protocol for details. 170 func (device *Device) IpcSetOperation(r io.Reader) (err error) { 171 device.ipcMutex.Lock() 172 defer device.ipcMutex.Unlock() 173 174 defer func() { 175 if err != nil { 176 device.log.Errorf("%v", err) 177 } 178 }() 179 180 peer := new(ipcSetPeer) 181 deviceConfig := true 182 183 tempASecCfg := aSecCfgType{} 184 scanner := bufio.NewScanner(r) 185 for scanner.Scan() { 186 line := scanner.Text() 187 if line == "" { 188 // Blank line means terminate operation. 189 err := device.handlePostConfig(&tempASecCfg) 190 if err != nil { 191 return err 192 } 193 peer.handlePostConfig() 194 return nil 195 } 196 key, value, ok := strings.Cut(line, "=") 197 if !ok { 198 return ipcErrorf( 199 ipc.IpcErrorProtocol, 200 "failed to parse line %q", 201 line, 202 ) 203 } 204 205 if key == "public_key" { 206 if deviceConfig { 207 deviceConfig = false 208 } 209 peer.handlePostConfig() 210 // Load/create the peer we are now configuring. 211 err := device.handlePublicKeyLine(peer, value) 212 if err != nil { 213 return err 214 } 215 continue 216 } 217 218 var err error 219 if deviceConfig { 220 err = device.handleDeviceLine(key, value, &tempASecCfg) 221 } else { 222 err = device.handlePeerLine(peer, key, value) 223 } 224 if err != nil { 225 return err 226 } 227 } 228 err = device.handlePostConfig(&tempASecCfg) 229 if err != nil { 230 return err 231 } 232 peer.handlePostConfig() 233 234 if err := scanner.Err(); err != nil { 235 return ipcErrorf(ipc.IpcErrorIO, "failed to read input: %w", err) 236 } 237 return nil 238 } 239 240 func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgType) error { 241 switch key { 242 case "private_key": 243 var sk NoisePrivateKey 244 err := sk.FromMaybeZeroHex(value) 245 if err != nil { 246 return ipcErrorf(ipc.IpcErrorInvalid, "failed to set private_key: %w", err) 247 } 248 device.log.Verbosef("UAPI: Updating private key") 249 device.SetPrivateKey(sk) 250 251 case "listen_port": 252 port, err := strconv.ParseUint(value, 10, 16) 253 if err != nil { 254 return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse listen_port: %w", err) 255 } 256 257 // update port and rebind 258 device.log.Verbosef("UAPI: Updating listen port") 259 260 device.net.Lock() 261 device.net.port = uint16(port) 262 device.net.Unlock() 263 264 if err := device.BindUpdate(); err != nil { 265 return ipcErrorf(ipc.IpcErrorPortInUse, "failed to set listen_port: %w", err) 266 } 267 268 case "fwmark": 269 mark, err := strconv.ParseUint(value, 10, 32) 270 if err != nil { 271 return ipcErrorf(ipc.IpcErrorInvalid, "invalid fwmark: %w", err) 272 } 273 274 device.log.Verbosef("UAPI: Updating fwmark") 275 if err := device.BindSetMark(uint32(mark)); err != nil { 276 return ipcErrorf(ipc.IpcErrorPortInUse, "failed to update fwmark: %w", err) 277 } 278 279 case "replace_peers": 280 if value != "true" { 281 return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value) 282 } 283 device.log.Verbosef("UAPI: Removing all peers") 284 device.RemoveAllPeers() 285 286 case "jc": 287 junkPacketCount, err := strconv.Atoi(value) 288 if err != nil { 289 return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_count %w", err) 290 } 291 device.log.Verbosef("UAPI: Updating junk_packet_count") 292 tempASecCfg.junkPacketCount = junkPacketCount 293 tempASecCfg.isSet = true 294 295 case "jmin": 296 junkPacketMinSize, err := strconv.Atoi(value) 297 if err != nil { 298 return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_min_size %w", err) 299 } 300 device.log.Verbosef("UAPI: Updating junk_packet_min_size") 301 tempASecCfg.junkPacketMinSize = junkPacketMinSize 302 tempASecCfg.isSet = true 303 304 case "jmax": 305 junkPacketMaxSize, err := strconv.Atoi(value) 306 if err != nil { 307 return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_max_size %w", err) 308 } 309 device.log.Verbosef("UAPI: Updating junk_packet_max_size") 310 tempASecCfg.junkPacketMaxSize = junkPacketMaxSize 311 tempASecCfg.isSet = true 312 313 case "s1": 314 initPacketJunkSize, err := strconv.Atoi(value) 315 if err != nil { 316 return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_junk_size %w", err) 317 } 318 device.log.Verbosef("UAPI: Updating init_packet_junk_size") 319 tempASecCfg.initPacketJunkSize = initPacketJunkSize 320 tempASecCfg.isSet = true 321 322 case "s2": 323 responsePacketJunkSize, err := strconv.Atoi(value) 324 if err != nil { 325 return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_junk_size %w", err) 326 } 327 device.log.Verbosef("UAPI: Updating response_packet_junk_size") 328 tempASecCfg.responsePacketJunkSize = responsePacketJunkSize 329 tempASecCfg.isSet = true 330 331 case "h1": 332 initPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) 333 if err != nil { 334 return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_magic_header %w", err) 335 } 336 tempASecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader) 337 tempASecCfg.isSet = true 338 339 case "h2": 340 responsePacketMagicHeader, err := strconv.ParseUint(value, 10, 32) 341 if err != nil { 342 return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_magic_header %w", err) 343 } 344 tempASecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader) 345 tempASecCfg.isSet = true 346 347 case "h3": 348 underloadPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) 349 if err != nil { 350 return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse underload_packet_magic_header %w", err) 351 } 352 tempASecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader) 353 tempASecCfg.isSet = true 354 355 case "h4": 356 transportPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) 357 if err != nil { 358 return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse transport_packet_magic_header %w", err) 359 } 360 tempASecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader) 361 tempASecCfg.isSet = true 362 363 default: 364 return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key) 365 } 366 367 return nil 368 } 369 370 // An ipcSetPeer is the current state of an IPC set operation on a peer. 371 type ipcSetPeer struct { 372 *Peer // Peer is the current peer being operated on 373 dummy bool // dummy reports whether this peer is a temporary, placeholder peer 374 created bool // new reports whether this is a newly created peer 375 pkaOn bool // pkaOn reports whether the peer had the persistent keepalive turn on 376 } 377 378 func (peer *ipcSetPeer) handlePostConfig() { 379 if peer.Peer == nil || peer.dummy { 380 return 381 } 382 if peer.created { 383 peer.endpoint.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint.val != nil 384 } 385 if peer.device.isUp() { 386 peer.Start() 387 if peer.pkaOn { 388 peer.SendKeepalive() 389 } 390 peer.SendStagedPackets() 391 } 392 } 393 394 func (device *Device) handlePublicKeyLine( 395 peer *ipcSetPeer, 396 value string, 397 ) error { 398 // Load/create the peer we are configuring. 399 var publicKey NoisePublicKey 400 err := publicKey.FromHex(value) 401 if err != nil { 402 return ipcErrorf(ipc.IpcErrorInvalid, "failed to get peer by public key: %w", err) 403 } 404 405 // Ignore peer with the same public key as this device. 406 device.staticIdentity.RLock() 407 peer.dummy = device.staticIdentity.publicKey.Equals(publicKey) 408 device.staticIdentity.RUnlock() 409 410 if peer.dummy { 411 peer.Peer = &Peer{} 412 } else { 413 peer.Peer = device.LookupPeer(publicKey) 414 } 415 416 peer.created = peer.Peer == nil 417 if peer.created { 418 peer.Peer, err = device.NewPeer(publicKey) 419 if err != nil { 420 return ipcErrorf(ipc.IpcErrorInvalid, "failed to create new peer: %w", err) 421 } 422 device.log.Verbosef("%v - UAPI: Created", peer.Peer) 423 } 424 return nil 425 } 426 427 func (device *Device) handlePeerLine( 428 peer *ipcSetPeer, 429 key, value string, 430 ) error { 431 switch key { 432 case "update_only": 433 // allow disabling of creation 434 if value != "true" { 435 return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value) 436 } 437 if peer.created && !peer.dummy { 438 device.RemovePeer(peer.handshake.remoteStatic) 439 peer.Peer = &Peer{} 440 peer.dummy = true 441 } 442 443 case "remove": 444 // remove currently selected peer from device 445 if value != "true" { 446 return ipcErrorf(ipc.IpcErrorInvalid, "failed to set remove, invalid value: %v", value) 447 } 448 if !peer.dummy { 449 device.log.Verbosef("%v - UAPI: Removing", peer.Peer) 450 device.RemovePeer(peer.handshake.remoteStatic) 451 } 452 peer.Peer = &Peer{} 453 peer.dummy = true 454 455 case "preshared_key": 456 device.log.Verbosef("%v - UAPI: Updating preshared key", peer.Peer) 457 458 peer.handshake.mutex.Lock() 459 err := peer.handshake.presharedKey.FromHex(value) 460 peer.handshake.mutex.Unlock() 461 462 if err != nil { 463 return ipcErrorf(ipc.IpcErrorInvalid, "failed to set preshared key: %w", err) 464 } 465 466 case "endpoint": 467 device.log.Verbosef("%v - UAPI: Updating endpoint", peer.Peer) 468 endpoint, err := device.net.bind.ParseEndpoint(value) 469 if err != nil { 470 return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err) 471 } 472 peer.endpoint.Lock() 473 defer peer.endpoint.Unlock() 474 peer.endpoint.val = endpoint 475 476 case "persistent_keepalive_interval": 477 device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer) 478 479 secs, err := strconv.ParseUint(value, 10, 16) 480 if err != nil { 481 return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err) 482 } 483 484 old := peer.persistentKeepaliveInterval.Swap(uint32(secs)) 485 486 // Send immediate keepalive if we're turning it on and before it wasn't on. 487 peer.pkaOn = old == 0 && secs != 0 488 489 case "replace_allowed_ips": 490 device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer) 491 if value != "true" { 492 return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value) 493 } 494 if peer.dummy { 495 return nil 496 } 497 device.allowedips.RemoveByPeer(peer.Peer) 498 499 case "allowed_ip": 500 device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer) 501 prefix, err := netip.ParsePrefix(value) 502 if err != nil { 503 return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err) 504 } 505 if peer.dummy { 506 return nil 507 } 508 device.allowedips.Insert(prefix, peer.Peer) 509 510 case "protocol_version": 511 if value != "1" { 512 return ipcErrorf(ipc.IpcErrorInvalid, "invalid protocol version: %v", value) 513 } 514 515 default: 516 return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI peer key: %v", key) 517 } 518 519 return nil 520 } 521 522 func (device *Device) IpcGet() (string, error) { 523 buf := new(strings.Builder) 524 if err := device.IpcGetOperation(buf); err != nil { 525 return "", err 526 } 527 return buf.String(), nil 528 } 529 530 func (device *Device) IpcSet(uapiConf string) error { 531 return device.IpcSetOperation(strings.NewReader(uapiConf)) 532 } 533 534 func (device *Device) IpcHandle(socket net.Conn) { 535 defer socket.Close() 536 537 buffered := func(s io.ReadWriter) *bufio.ReadWriter { 538 reader := bufio.NewReader(s) 539 writer := bufio.NewWriter(s) 540 return bufio.NewReadWriter(reader, writer) 541 }(socket) 542 543 for { 544 op, err := buffered.ReadString('\n') 545 if err != nil { 546 return 547 } 548 549 // handle operation 550 switch op { 551 case "set=1\n": 552 err = device.IpcSetOperation(buffered.Reader) 553 case "get=1\n": 554 var nextByte byte 555 nextByte, err = buffered.ReadByte() 556 if err != nil { 557 return 558 } 559 if nextByte != '\n' { 560 err = ipcErrorf(ipc.IpcErrorInvalid, "trailing character in UAPI get: %q", nextByte) 561 break 562 } 563 err = device.IpcGetOperation(buffered.Writer) 564 default: 565 device.log.Errorf("invalid UAPI operation: %v", op) 566 return 567 } 568 569 // write status 570 var status *IPCError 571 if err != nil && !errors.As(err, &status) { 572 // shouldn't happen 573 status = ipcErrorf(ipc.IpcErrorUnknown, "other UAPI error: %w", err) 574 } 575 if status != nil { 576 device.log.Errorf("%v", status) 577 fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode()) 578 } else { 579 fmt.Fprintf(buffered, "errno=0\n\n") 580 } 581 buffered.Flush() 582 } 583 }