github.com/bepass-org/wireguard-go@v1.0.4-rc2.0.20240304192354-ebce6572bc24/device/device.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package device 7 8 import ( 9 "runtime" 10 "sync" 11 "sync/atomic" 12 "time" 13 14 "github.com/bepass-org/wireguard-go/conn" 15 "github.com/bepass-org/wireguard-go/ratelimiter" 16 "github.com/bepass-org/wireguard-go/rwcancel" 17 "github.com/bepass-org/wireguard-go/tun" 18 ) 19 20 type Device struct { 21 state struct { 22 // state holds the device's state. It is accessed atomically. 23 // Use the device.deviceState method to read it. 24 // device.deviceState does not acquire the mutex, so it captures only a snapshot. 25 // During state transitions, the state variable is updated before the device itself. 26 // The state is thus either the current state of the device or 27 // the intended future state of the device. 28 // For example, while executing a call to Up, state will be deviceStateUp. 29 // There is no guarantee that that intended future state of the device 30 // will become the actual state; Up can fail. 31 // The device can also change state multiple times between time of check and time of use. 32 // Unsynchronized uses of state must therefore be advisory/best-effort only. 33 state atomic.Uint32 // actually a deviceState, but typed uint32 for convenience 34 // stopping blocks until all inputs to Device have been closed. 35 stopping sync.WaitGroup 36 // mu protects state changes. 37 sync.Mutex 38 } 39 40 net struct { 41 stopping sync.WaitGroup 42 sync.RWMutex 43 bind conn.Bind // bind interface 44 netlinkCancel *rwcancel.RWCancel 45 port uint16 // listening port 46 fwmark uint32 // mark value (0 = disabled) 47 brokenRoaming bool 48 } 49 50 staticIdentity struct { 51 sync.RWMutex 52 privateKey NoisePrivateKey 53 publicKey NoisePublicKey 54 } 55 56 peers struct { 57 sync.RWMutex // protects keyMap 58 keyMap map[NoisePublicKey]*Peer 59 } 60 61 rate struct { 62 underLoadUntil atomic.Int64 63 limiter ratelimiter.Ratelimiter 64 } 65 66 allowedips AllowedIPs 67 indexTable IndexTable 68 cookieChecker CookieChecker 69 70 pool struct { 71 inboundElementsContainer *WaitPool 72 outboundElementsContainer *WaitPool 73 messageBuffers *WaitPool 74 inboundElements *WaitPool 75 outboundElements *WaitPool 76 } 77 78 queue struct { 79 encryption *outboundQueue 80 decryption *inboundQueue 81 handshake *handshakeQueue 82 } 83 84 tun struct { 85 device tun.Device 86 mtu atomic.Int32 87 } 88 89 trick bool 90 91 ipcMutex sync.RWMutex 92 closed chan struct{} 93 log *Logger 94 } 95 96 // deviceState represents the state of a Device. 97 // There are three states: down, up, closed. 98 // Transitions: 99 // 100 // down -----+ 101 // ↑↓ ↓ 102 // up -> closed 103 type deviceState uint32 104 105 //go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState 106 const ( 107 deviceStateDown deviceState = iota 108 deviceStateUp 109 deviceStateClosed 110 ) 111 112 // deviceState returns device.state.state as a deviceState 113 // See those docs for how to interpret this value. 114 func (device *Device) deviceState() deviceState { 115 return deviceState(device.state.state.Load()) 116 } 117 118 // isClosed reports whether the device is closed (or is closing). 119 // See device.state.state comments for how to interpret this value. 120 func (device *Device) isClosed() bool { 121 return device.deviceState() == deviceStateClosed 122 } 123 124 // isUp reports whether the device is up (or is attempting to come up). 125 // See device.state.state comments for how to interpret this value. 126 func (device *Device) isUp() bool { 127 return device.deviceState() == deviceStateUp 128 } 129 130 // Must hold device.peers.Lock() 131 func removePeerLocked(device *Device, peer *Peer, key NoisePublicKey) { 132 // stop routing and processing of packets 133 device.allowedips.RemoveByPeer(peer) 134 peer.Stop() 135 136 // remove from peer map 137 delete(device.peers.keyMap, key) 138 } 139 140 // changeState attempts to change the device state to match want. 141 func (device *Device) changeState(want deviceState) (err error) { 142 device.state.Lock() 143 defer device.state.Unlock() 144 old := device.deviceState() 145 if old == deviceStateClosed { 146 // once closed, always closed 147 device.log.Verbosef("Interface closed, ignored requested state %s", want) 148 return nil 149 } 150 switch want { 151 case old: 152 return nil 153 case deviceStateUp: 154 device.state.state.Store(uint32(deviceStateUp)) 155 err = device.upLocked() 156 if err == nil { 157 break 158 } 159 fallthrough // up failed; bring the device all the way back down 160 case deviceStateDown: 161 device.state.state.Store(uint32(deviceStateDown)) 162 errDown := device.downLocked() 163 if err == nil { 164 err = errDown 165 } 166 } 167 device.log.Verbosef("Interface state was %s, requested %s, now %s", old, want, device.deviceState()) 168 return 169 } 170 171 // upLocked attempts to bring the device up and reports whether it succeeded. 172 // The caller must hold device.state.mu and is responsible for updating device.state.state. 173 func (device *Device) upLocked() error { 174 if err := device.BindUpdate(); err != nil { 175 device.log.Errorf("Unable to update bind: %v", err) 176 return err 177 } 178 179 // The IPC set operation waits for peers to be created before calling Start() on them, 180 // so if there's a concurrent IPC set request happening, we should wait for it to complete. 181 device.ipcMutex.Lock() 182 defer device.ipcMutex.Unlock() 183 184 device.peers.RLock() 185 for _, peer := range device.peers.keyMap { 186 peer.Start() 187 if peer.persistentKeepaliveInterval.Load() > 0 { 188 peer.SendKeepalive() 189 } 190 } 191 device.peers.RUnlock() 192 return nil 193 } 194 195 // downLocked attempts to bring the device down. 196 // The caller must hold device.state.mu and is responsible for updating device.state.state. 197 func (device *Device) downLocked() error { 198 err := device.BindClose() 199 if err != nil { 200 device.log.Errorf("Bind close failed: %v", err) 201 } 202 203 device.peers.RLock() 204 for _, peer := range device.peers.keyMap { 205 peer.Stop() 206 } 207 device.peers.RUnlock() 208 return err 209 } 210 211 func (device *Device) Up() error { 212 return device.changeState(deviceStateUp) 213 } 214 215 func (device *Device) Down() error { 216 return device.changeState(deviceStateDown) 217 } 218 219 func (device *Device) IsUnderLoad() bool { 220 // check if currently under load 221 now := time.Now() 222 underLoad := len(device.queue.handshake.c) >= QueueHandshakeSize/8 223 if underLoad { 224 device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime).UnixNano()) 225 return true 226 } 227 // check if recently under load 228 return device.rate.underLoadUntil.Load() > now.UnixNano() 229 } 230 231 func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { 232 // lock required resources 233 234 device.staticIdentity.Lock() 235 defer device.staticIdentity.Unlock() 236 237 if sk.Equals(device.staticIdentity.privateKey) { 238 return nil 239 } 240 241 device.peers.Lock() 242 defer device.peers.Unlock() 243 244 lockedPeers := make([]*Peer, 0, len(device.peers.keyMap)) 245 for _, peer := range device.peers.keyMap { 246 peer.handshake.mutex.RLock() 247 lockedPeers = append(lockedPeers, peer) 248 } 249 250 // remove peers with matching public keys 251 252 publicKey := sk.publicKey() 253 for key, peer := range device.peers.keyMap { 254 if peer.handshake.remoteStatic.Equals(publicKey) { 255 peer.handshake.mutex.RUnlock() 256 removePeerLocked(device, peer, key) 257 peer.handshake.mutex.RLock() 258 } 259 } 260 261 // update key material 262 263 device.staticIdentity.privateKey = sk 264 device.staticIdentity.publicKey = publicKey 265 device.cookieChecker.Init(publicKey) 266 267 // do static-static DH pre-computations 268 269 expiredPeers := make([]*Peer, 0, len(device.peers.keyMap)) 270 for _, peer := range device.peers.keyMap { 271 handshake := &peer.handshake 272 handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic) 273 expiredPeers = append(expiredPeers, peer) 274 } 275 276 for _, peer := range lockedPeers { 277 peer.handshake.mutex.RUnlock() 278 } 279 for _, peer := range expiredPeers { 280 peer.ExpireCurrentKeypairs() 281 } 282 283 return nil 284 } 285 286 func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger, trick bool) *Device { 287 device := new(Device) 288 device.trick = trick 289 device.state.state.Store(uint32(deviceStateDown)) 290 device.closed = make(chan struct{}) 291 device.log = logger 292 device.net.bind = bind 293 device.tun.device = tunDevice 294 mtu, err := device.tun.device.MTU() 295 if err != nil { 296 device.log.Errorf("Trouble determining MTU, assuming default: %v", err) 297 mtu = DefaultMTU 298 } 299 device.tun.mtu.Store(int32(mtu)) 300 device.peers.keyMap = make(map[NoisePublicKey]*Peer) 301 device.rate.limiter.Init() 302 device.indexTable.Init() 303 304 device.PopulatePools() 305 306 // create queues 307 308 device.queue.handshake = newHandshakeQueue() 309 device.queue.encryption = newOutboundQueue() 310 device.queue.decryption = newInboundQueue() 311 312 // start workers 313 314 cpus := runtime.NumCPU() 315 device.state.stopping.Wait() 316 device.queue.encryption.wg.Add(cpus) // One for each RoutineHandshake 317 for i := 0; i < cpus; i++ { 318 go device.RoutineEncryption(i + 1) 319 go device.RoutineDecryption(i + 1) 320 go device.RoutineHandshake(i + 1) 321 } 322 323 device.state.stopping.Add(1) // RoutineReadFromTUN 324 device.queue.encryption.wg.Add(1) // RoutineReadFromTUN 325 go device.RoutineReadFromTUN() 326 go device.RoutineTUNEventReader() 327 328 return device 329 } 330 331 // BatchSize returns the BatchSize for the device as a whole which is the max of 332 // the bind batch size and the tun batch size. The batch size reported by device 333 // is the size used to construct memory pools, and is the allowed batch size for 334 // the lifetime of the device. 335 func (device *Device) BatchSize() int { 336 size := device.net.bind.BatchSize() 337 dSize := device.tun.device.BatchSize() 338 if size < dSize { 339 size = dSize 340 } 341 return size 342 } 343 344 func (device *Device) LookupPeer(pk NoisePublicKey) *Peer { 345 device.peers.RLock() 346 defer device.peers.RUnlock() 347 348 return device.peers.keyMap[pk] 349 } 350 351 func (device *Device) RemovePeer(key NoisePublicKey) { 352 device.peers.Lock() 353 defer device.peers.Unlock() 354 // stop peer and remove from routing 355 356 peer, ok := device.peers.keyMap[key] 357 if ok { 358 removePeerLocked(device, peer, key) 359 } 360 } 361 362 func (device *Device) RemoveAllPeers() { 363 device.peers.Lock() 364 defer device.peers.Unlock() 365 366 for key, peer := range device.peers.keyMap { 367 removePeerLocked(device, peer, key) 368 } 369 370 device.peers.keyMap = make(map[NoisePublicKey]*Peer) 371 } 372 373 func (device *Device) Close() { 374 device.state.Lock() 375 defer device.state.Unlock() 376 device.ipcMutex.Lock() 377 defer device.ipcMutex.Unlock() 378 if device.isClosed() { 379 return 380 } 381 device.state.state.Store(uint32(deviceStateClosed)) 382 device.log.Verbosef("Device closing") 383 384 device.tun.device.Close() 385 device.downLocked() 386 387 // Remove peers before closing queues, 388 // because peers assume that queues are active. 389 device.RemoveAllPeers() 390 391 // We kept a reference to the encryption and decryption queues, 392 // in case we started any new peers that might write to them. 393 // No new peers are coming; we are done with these queues. 394 device.queue.encryption.wg.Done() 395 device.queue.decryption.wg.Done() 396 device.queue.handshake.wg.Done() 397 device.state.stopping.Wait() 398 399 device.rate.limiter.Close() 400 401 device.log.Verbosef("Device closed") 402 close(device.closed) 403 } 404 405 func (device *Device) Wait() chan struct{} { 406 return device.closed 407 } 408 409 func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() { 410 if !device.isUp() { 411 return 412 } 413 414 device.peers.RLock() 415 for _, peer := range device.peers.keyMap { 416 peer.keypairs.RLock() 417 sendKeepalive := peer.keypairs.current != nil && !peer.keypairs.current.created.Add(RejectAfterTime).Before(time.Now()) 418 peer.keypairs.RUnlock() 419 if sendKeepalive { 420 peer.SendKeepalive() 421 } 422 } 423 device.peers.RUnlock() 424 } 425 426 // closeBindLocked closes the device's net.bind. 427 // The caller must hold the net mutex. 428 func closeBindLocked(device *Device) error { 429 var err error 430 netc := &device.net 431 if netc.netlinkCancel != nil { 432 netc.netlinkCancel.Cancel() 433 } 434 if netc.bind != nil { 435 err = netc.bind.Close() 436 } 437 netc.stopping.Wait() 438 return err 439 } 440 441 func (device *Device) Bind() conn.Bind { 442 device.net.Lock() 443 defer device.net.Unlock() 444 return device.net.bind 445 } 446 447 func (device *Device) BindSetMark(mark uint32) error { 448 device.net.Lock() 449 defer device.net.Unlock() 450 451 // check if modified 452 if device.net.fwmark == mark { 453 return nil 454 } 455 456 // update fwmark on existing bind 457 device.net.fwmark = mark 458 if device.isUp() && device.net.bind != nil { 459 if err := device.net.bind.SetMark(mark); err != nil { 460 return err 461 } 462 } 463 464 // clear cached source addresses 465 device.peers.RLock() 466 for _, peer := range device.peers.keyMap { 467 peer.markEndpointSrcForClearing() 468 } 469 device.peers.RUnlock() 470 471 return nil 472 } 473 474 func (device *Device) BindUpdate() error { 475 device.net.Lock() 476 defer device.net.Unlock() 477 478 // close existing sockets 479 if err := closeBindLocked(device); err != nil { 480 return err 481 } 482 483 // open new sockets 484 if !device.isUp() { 485 return nil 486 } 487 488 // bind to new port 489 var err error 490 var recvFns []conn.ReceiveFunc 491 netc := &device.net 492 493 recvFns, netc.port, err = netc.bind.Open(netc.port) 494 if err != nil { 495 netc.port = 0 496 return err 497 } 498 499 netc.netlinkCancel, err = device.startRouteListener(netc.bind) 500 if err != nil { 501 netc.bind.Close() 502 netc.port = 0 503 return err 504 } 505 506 // set fwmark 507 if netc.fwmark != 0 { 508 err = netc.bind.SetMark(netc.fwmark) 509 if err != nil { 510 return err 511 } 512 } 513 514 // clear cached source addresses 515 device.peers.RLock() 516 for _, peer := range device.peers.keyMap { 517 peer.markEndpointSrcForClearing() 518 } 519 device.peers.RUnlock() 520 521 // start receiving routines 522 device.net.stopping.Add(len(recvFns)) 523 device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption 524 device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake 525 batchSize := netc.bind.BatchSize() 526 for _, fn := range recvFns { 527 go device.RoutineReceiveIncoming(batchSize, fn) 528 } 529 530 device.log.Verbosef("UDP bind has been updated") 531 return nil 532 } 533 534 func (device *Device) BindClose() error { 535 device.net.Lock() 536 err := closeBindLocked(device) 537 device.net.Unlock() 538 return err 539 }