gitee.com/aurawing/surguard-go@v0.3.1-0.20240409071558-96509a61ecf3/device/device.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package device 7 8 import ( 9 "os" 10 "runtime" 11 "sync" 12 "sync/atomic" 13 "time" 14 15 "gitee.com/aurawing/surguard-go/conn" 16 "gitee.com/aurawing/surguard-go/ratelimiter" 17 "gitee.com/aurawing/surguard-go/rwcancel" 18 "gitee.com/aurawing/surguard-go/tun" 19 ) 20 21 type Device struct { 22 state struct { 23 // state holds the device's state. It is accessed atomically. 24 // Use the device.deviceState method to read it. 25 // device.deviceState does not acquire the mutex, so it captures only a snapshot. 26 // During state transitions, the state variable is updated before the device itself. 27 // The state is thus either the current state of the device or 28 // the intended future state of the device. 29 // For example, while executing a call to Up, state will be deviceStateUp. 30 // There is no guarantee that that intended future state of the device 31 // will become the actual state; Up can fail. 32 // The device can also change state multiple times between time of check and time of use. 33 // Unsynchronized uses of state must therefore be advisory/best-effort only. 34 state atomic.Uint32 // actually a deviceState, but typed uint32 for convenience 35 // stopping blocks until all inputs to Device have been closed. 36 stopping sync.WaitGroup 37 // mu protects state changes. 38 sync.Mutex 39 } 40 41 net struct { 42 stopping sync.WaitGroup 43 sync.RWMutex 44 bind conn.Bind // bind interface 45 netlinkCancel *rwcancel.RWCancel 46 ipv4Addr string 47 ipv6Addr string 48 port uint16 // listening port 49 fwmark uint32 // mark value (0 = disabled) 50 brokenRoaming bool 51 } 52 53 staticIdentity struct { 54 sync.RWMutex 55 privateKey NoisePrivateKey 56 publicKey NoisePublicKey 57 } 58 59 peers struct { 60 sync.RWMutex // protects keyMap 61 keyMap map[NoisePublicKey]*Peer 62 } 63 64 rate struct { 65 underLoadUntil atomic.Int64 66 limiter ratelimiter.Ratelimiter 67 } 68 69 allowedips AllowedIPs 70 indexTable IndexTable 71 cookieChecker CookieChecker 72 73 pool struct { 74 inboundElementsContainer *WaitPool 75 outboundElementsContainer *WaitPool 76 messageBuffers *WaitPool 77 inboundElements *WaitPool 78 outboundElements *WaitPool 79 } 80 81 queue struct { 82 encryption *outboundQueue 83 decryption *inboundQueue 84 handshake *handshakeQueue 85 } 86 87 tun struct { 88 device tun.Device 89 mtu atomic.Int32 90 } 91 92 ipcMutex sync.RWMutex 93 closed chan struct{} 94 log *Logger 95 } 96 97 // deviceState represents the state of a Device. 98 // There are three states: down, up, closed. 99 // Transitions: 100 // 101 // down -----+ 102 // ↑↓ ↓ 103 // up -> closed 104 type deviceState uint32 105 106 //go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState 107 const ( 108 deviceStateDown deviceState = iota 109 deviceStateUp 110 deviceStateClosed 111 ) 112 113 const ( 114 SECP256K1 int = iota // SECP256K1 = 0 115 SM2 // SM2 = 1 116 ) 117 118 const ENV_SG_ALGO = "SG_ALGO" 119 120 var ALGO int 121 122 func init() { 123 algoStr := os.Getenv(ENV_SG_ALGO) 124 if algoStr == "SECP256K1" { 125 ALGO = SECP256K1 126 } else if algoStr == "SM2" { 127 ALGO = SM2 128 } else { 129 ALGO = SECP256K1 130 } 131 } 132 133 // deviceState returns device.state.state as a deviceState 134 // See those docs for how to interpret this value. 135 func (device *Device) deviceState() deviceState { 136 return deviceState(device.state.state.Load()) 137 } 138 139 // isClosed reports whether the device is closed (or is closing). 140 // See device.state.state comments for how to interpret this value. 141 func (device *Device) isClosed() bool { 142 return device.deviceState() == deviceStateClosed 143 } 144 145 // isUp reports whether the device is up (or is attempting to come up). 146 // See device.state.state comments for how to interpret this value. 147 func (device *Device) isUp() bool { 148 return device.deviceState() == deviceStateUp 149 } 150 151 // Must hold device.peers.Lock() 152 func removePeerLocked(device *Device, peer *Peer, key NoisePublicKey) { 153 // stop routing and processing of packets 154 device.allowedips.RemoveByPeer(peer) 155 peer.Stop() 156 157 // remove from peer map 158 delete(device.peers.keyMap, key) 159 } 160 161 // changeState attempts to change the device state to match want. 162 func (device *Device) changeState(want deviceState) (err error) { 163 device.state.Lock() 164 defer device.state.Unlock() 165 old := device.deviceState() 166 if old == deviceStateClosed { 167 // once closed, always closed 168 device.log.Verbosef("Interface closed, ignored requested state %s", want) 169 return nil 170 } 171 switch want { 172 case old: 173 return nil 174 case deviceStateUp: 175 device.state.state.Store(uint32(deviceStateUp)) 176 err = device.upLocked() 177 if err == nil { 178 break 179 } 180 fallthrough // up failed; bring the device all the way back down 181 case deviceStateDown: 182 device.state.state.Store(uint32(deviceStateDown)) 183 errDown := device.downLocked() 184 if err == nil { 185 err = errDown 186 } 187 } 188 device.log.Verbosef("Interface state was %s, requested %s, now %s", old, want, device.deviceState()) 189 return 190 } 191 192 // upLocked attempts to bring the device up and reports whether it succeeded. 193 // The caller must hold device.state.mu and is responsible for updating device.state.state. 194 func (device *Device) upLocked() error { 195 if err := device.BindUpdate(); err != nil { 196 device.log.Errorf("Unable to update bind: %v", err) 197 return err 198 } 199 200 // The IPC set operation waits for peers to be created before calling Start() on them, 201 // so if there's a concurrent IPC set request happening, we should wait for it to complete. 202 device.ipcMutex.Lock() 203 defer device.ipcMutex.Unlock() 204 205 device.peers.RLock() 206 for _, peer := range device.peers.keyMap { 207 peer.Start() 208 if peer.persistentKeepaliveInterval.Load() > 0 { 209 peer.SendKeepalive() 210 } 211 } 212 device.peers.RUnlock() 213 return nil 214 } 215 216 // downLocked attempts to bring the device down. 217 // The caller must hold device.state.mu and is responsible for updating device.state.state. 218 func (device *Device) downLocked() error { 219 err := device.BindClose() 220 if err != nil { 221 device.log.Errorf("Bind close failed: %v", err) 222 } 223 224 device.peers.RLock() 225 for _, peer := range device.peers.keyMap { 226 peer.Stop() 227 } 228 device.peers.RUnlock() 229 return err 230 } 231 232 func (device *Device) Up() error { 233 return device.changeState(deviceStateUp) 234 } 235 236 func (device *Device) Down() error { 237 return device.changeState(deviceStateDown) 238 } 239 240 func (device *Device) IsUnderLoad() bool { 241 // check if currently under load 242 now := time.Now() 243 underLoad := len(device.queue.handshake.c) >= QueueHandshakeSize/8 244 if underLoad { 245 device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime).UnixNano()) 246 return true 247 } 248 // check if recently under load 249 return device.rate.underLoadUntil.Load() > now.UnixNano() 250 } 251 252 func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { 253 // lock required resources 254 255 device.staticIdentity.Lock() 256 defer device.staticIdentity.Unlock() 257 258 if sk.Equals(device.staticIdentity.privateKey) { 259 return nil 260 } 261 262 device.peers.Lock() 263 defer device.peers.Unlock() 264 265 lockedPeers := make([]*Peer, 0, len(device.peers.keyMap)) 266 for _, peer := range device.peers.keyMap { 267 peer.handshake.mutex.RLock() 268 lockedPeers = append(lockedPeers, peer) 269 } 270 271 // remove peers with matching public keys 272 273 publicKey := sk.publicKey() 274 for key, peer := range device.peers.keyMap { 275 if peer.handshake.remoteStatic.Equals(publicKey) { 276 peer.handshake.mutex.RUnlock() 277 removePeerLocked(device, peer, key) 278 peer.handshake.mutex.RLock() 279 } 280 } 281 282 // update key material 283 284 device.staticIdentity.privateKey = sk 285 device.staticIdentity.publicKey = publicKey 286 device.cookieChecker.Init(publicKey) 287 288 // do static-static DH pre-computations 289 290 expiredPeers := make([]*Peer, 0, len(device.peers.keyMap)) 291 for _, peer := range device.peers.keyMap { 292 handshake := &peer.handshake 293 handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic) 294 expiredPeers = append(expiredPeers, peer) 295 } 296 297 for _, peer := range lockedPeers { 298 peer.handshake.mutex.RUnlock() 299 } 300 for _, peer := range expiredPeers { 301 peer.ExpireCurrentKeypairs() 302 } 303 304 return nil 305 } 306 307 func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device { 308 device := new(Device) 309 device.state.state.Store(uint32(deviceStateDown)) 310 device.closed = make(chan struct{}) 311 device.log = logger 312 device.net.bind = bind 313 device.tun.device = tunDevice 314 mtu, err := device.tun.device.MTU() 315 if err != nil { 316 device.log.Errorf("Trouble determining MTU, assuming default: %v", err) 317 mtu = DefaultMTU 318 } 319 device.tun.mtu.Store(int32(mtu)) 320 device.peers.keyMap = make(map[NoisePublicKey]*Peer) 321 device.rate.limiter.Init() 322 device.indexTable.Init() 323 324 device.PopulatePools() 325 326 // create queues 327 328 device.queue.handshake = newHandshakeQueue() 329 device.queue.encryption = newOutboundQueue() 330 device.queue.decryption = newInboundQueue() 331 332 // start workers 333 334 cpus := runtime.NumCPU() 335 device.state.stopping.Wait() 336 device.queue.encryption.wg.Add(cpus) // One for each RoutineHandshake 337 for i := 0; i < cpus; i++ { 338 go device.RoutineEncryption(i + 1) 339 go device.RoutineDecryption(i + 1) 340 go device.RoutineHandshake(i + 1) 341 } 342 343 device.state.stopping.Add(1) // RoutineReadFromTUN 344 device.queue.encryption.wg.Add(1) // RoutineReadFromTUN 345 go device.RoutineReadFromTUN() 346 go device.RoutineTUNEventReader() 347 348 return device 349 } 350 351 // BatchSize returns the BatchSize for the device as a whole which is the max of 352 // the bind batch size and the tun batch size. The batch size reported by device 353 // is the size used to construct memory pools, and is the allowed batch size for 354 // the lifetime of the device. 355 func (device *Device) BatchSize() int { 356 size := device.net.bind.BatchSize() 357 dSize := device.tun.device.BatchSize() 358 if size < dSize { 359 size = dSize 360 } 361 return size 362 } 363 364 func (device *Device) LookupPeer(pk NoisePublicKey) *Peer { 365 device.peers.RLock() 366 defer device.peers.RUnlock() 367 368 return device.peers.keyMap[pk] 369 } 370 371 func (device *Device) RemovePeer(key NoisePublicKey) { 372 device.peers.Lock() 373 defer device.peers.Unlock() 374 // stop peer and remove from routing 375 376 peer, ok := device.peers.keyMap[key] 377 if ok { 378 removePeerLocked(device, peer, key) 379 } 380 } 381 382 func (device *Device) RemoveAllPeers() { 383 device.peers.Lock() 384 defer device.peers.Unlock() 385 386 for key, peer := range device.peers.keyMap { 387 removePeerLocked(device, peer, key) 388 } 389 390 device.peers.keyMap = make(map[NoisePublicKey]*Peer) 391 } 392 393 func (device *Device) Close() { 394 device.state.Lock() 395 defer device.state.Unlock() 396 device.ipcMutex.Lock() 397 defer device.ipcMutex.Unlock() 398 if device.isClosed() { 399 return 400 } 401 device.state.state.Store(uint32(deviceStateClosed)) 402 device.log.Verbosef("Device closing") 403 404 device.tun.device.Close() 405 device.downLocked() 406 407 // Remove peers before closing queues, 408 // because peers assume that queues are active. 409 device.RemoveAllPeers() 410 411 // We kept a reference to the encryption and decryption queues, 412 // in case we started any new peers that might write to them. 413 // No new peers are coming; we are done with these queues. 414 device.queue.encryption.wg.Done() 415 device.queue.decryption.wg.Done() 416 device.queue.handshake.wg.Done() 417 device.state.stopping.Wait() 418 419 device.rate.limiter.Close() 420 421 device.log.Verbosef("Device closed") 422 close(device.closed) 423 } 424 425 func (device *Device) Wait() chan struct{} { 426 return device.closed 427 } 428 429 func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() { 430 if !device.isUp() { 431 return 432 } 433 434 device.peers.RLock() 435 for _, peer := range device.peers.keyMap { 436 peer.keypairs.RLock() 437 sendKeepalive := peer.keypairs.current != nil && !peer.keypairs.current.created.Add(RejectAfterTime).Before(time.Now()) 438 peer.keypairs.RUnlock() 439 if sendKeepalive { 440 peer.SendKeepalive() 441 } 442 } 443 device.peers.RUnlock() 444 } 445 446 // closeBindLocked closes the device's net.bind. 447 // The caller must hold the net mutex. 448 func closeBindLocked(device *Device) error { 449 var err error 450 netc := &device.net 451 if netc.netlinkCancel != nil { 452 netc.netlinkCancel.Cancel() 453 } 454 if netc.bind != nil { 455 err = netc.bind.Close() 456 } 457 netc.stopping.Wait() 458 return err 459 } 460 461 func (device *Device) Bind() conn.Bind { 462 device.net.Lock() 463 defer device.net.Unlock() 464 return device.net.bind 465 } 466 467 func (device *Device) BindSetMark(mark uint32) error { 468 device.net.Lock() 469 defer device.net.Unlock() 470 471 // check if modified 472 if device.net.fwmark == mark { 473 return nil 474 } 475 476 // update fwmark on existing bind 477 device.net.fwmark = mark 478 if device.isUp() && device.net.bind != nil { 479 if err := device.net.bind.SetMark(mark); err != nil { 480 return err 481 } 482 } 483 484 // clear cached source addresses 485 device.peers.RLock() 486 for _, peer := range device.peers.keyMap { 487 peer.markEndpointSrcForClearing() 488 } 489 device.peers.RUnlock() 490 491 return nil 492 } 493 494 func (device *Device) BindUpdate() error { 495 device.net.Lock() 496 defer device.net.Unlock() 497 498 // close existing sockets 499 if err := closeBindLocked(device); err != nil { 500 return err 501 } 502 503 // open new sockets 504 if !device.isUp() { 505 return nil 506 } 507 508 // bind to new port 509 var err error 510 var recvFns []conn.ReceiveFunc 511 netc := &device.net 512 513 recvFns, netc.port, err = netc.bind.Open(netc.ipv4Addr, netc.ipv6Addr, netc.port) 514 if err != nil { 515 netc.port = 0 516 return err 517 } 518 519 netc.netlinkCancel, err = device.startRouteListener(netc.bind) 520 if err != nil { 521 netc.bind.Close() 522 netc.port = 0 523 return err 524 } 525 526 // set fwmark 527 if netc.fwmark != 0 { 528 err = netc.bind.SetMark(netc.fwmark) 529 if err != nil { 530 return err 531 } 532 } 533 534 // clear cached source addresses 535 device.peers.RLock() 536 for _, peer := range device.peers.keyMap { 537 peer.markEndpointSrcForClearing() 538 } 539 device.peers.RUnlock() 540 541 // start receiving routines 542 device.net.stopping.Add(len(recvFns)) 543 device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption 544 device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake 545 batchSize := netc.bind.BatchSize() 546 for _, fn := range recvFns { 547 go device.RoutineReceiveIncoming(batchSize, fn) 548 } 549 550 device.log.Verbosef("UDP bind has been updated") 551 return nil 552 } 553 554 func (device *Device) BindClose() error { 555 device.net.Lock() 556 err := closeBindLocked(device) 557 device.net.Unlock() 558 return err 559 }