/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
 */

package device

import (
	"runtime"
	"sync"
	"sync/atomic"
	"time"

	"github.com/cawidtu/notwireguard-go/conn"
	"github.com/cawidtu/notwireguard-go/ratelimiter"
	"github.com/cawidtu/notwireguard-go/rwcancel"
	"github.com/cawidtu/notwireguard-go/tun"
)

// Device ties a TUN device to a network bind (see NewDevice) and holds the
// state shared by all of its peers: the static identity, the peer table,
// allowed-IP routing, the worker queues and the buffer pools.
type Device struct {
	state struct {
		// state holds the device's state. It is accessed atomically.
		// Use the device.deviceState method to read it.
		// device.deviceState does not acquire the mutex, so it captures only a snapshot.
		// During state transitions, the state variable is updated before the device itself.
		// The state is thus either the current state of the device or
		// the intended future state of the device.
		// For example, while executing a call to Up, state will be deviceStateUp.
		// There is no guarantee that the intended future state of the device
		// will become the actual state; Up can fail.
		// The device can also change state multiple times between time of check and time of use.
		// Unsynchronized uses of state must therefore be advisory/best-effort only.
		state uint32 // actually a deviceState, but typed uint32 for convenience
		// stopping blocks until all inputs to Device have been closed.
		// NewDevice counts RoutineReadFromTUN in it; Close waits on it.
		stopping sync.WaitGroup
		// The embedded Mutex serializes state changes: changeState and Close
		// hold it for the whole transition.
		sync.Mutex
	}

	// net holds the UDP bind and its settings. The embedded RWMutex is the
	// "net mutex"; Bind, BindSetMark, BindUpdate and BindClose take it.
	// stopping is waited on by closeBindLocked; BindUpdate adds one count per
	// receive routine it starts. netlinkCancel comes from startRouteListener
	// and is cancelled by closeBindLocked when non-nil.
	net struct {
		stopping sync.WaitGroup
		sync.RWMutex
		bind          conn.Bind // bind interface
		netlinkCancel *rwcancel.RWCancel
		port          uint16 // listening port
		fwmark        uint32 // mark value (0 = disabled)
		brokenRoaming bool
	}

	// staticIdentity is the device's own key pair; SetPrivateKey replaces it
	// while holding the write lock.
	staticIdentity struct {
		sync.RWMutex
		privateKey NoisePrivateKey
		publicKey  NoisePublicKey
		// obfuscator is derived from publicKey by wgNoiseCreateObfuscator
		// (see SetPrivateKey). It was marked "new" and is presumably an
		// addition specific to this fork; confirm its use in the handshake code.
		obfuscator [NoisePublicKeySize]byte
	}

	peers struct {
		sync.RWMutex // protects keyMap
		keyMap       map[NoisePublicKey]*Peer
	}

	// Keep this 8-byte aligned: underLoadUntil is read and written with
	// 64-bit atomic operations (see IsUnderLoad), which on 32-bit platforms
	// require the word to be 8-byte aligned. Re-check alignment before
	// reordering the fields above.
	rate struct {
		underLoadUntil int64
		limiter        ratelimiter.Ratelimiter
	}

	allowedips    AllowedIPs
	indexTable    IndexTable
	cookieChecker CookieChecker

	pool struct {
		messageBuffers   *WaitPool
		inboundElements  *WaitPool
		outboundElements *WaitPool
	}

	queue struct {
		encryption *outboundQueue
		decryption *inboundQueue
		handshake  *handshakeQueue
	}

	tun struct {
		device tun.Device
		mtu    int32
	}

	// ipcMutex is held by upLocked while starting peers, so that bringing the
	// device up waits for a concurrent IPC set operation to finish.
	ipcMutex sync.RWMutex
	// closed is closed at the very end of Close; Wait returns it.
	closed chan struct{}
	log    *Logger
}

// deviceState represents the state of a Device.
// There are three states: down, up, closed.
// Transitions:
//
//	down -----+
//	  ↑↓      ↓
//	  up -> closed
//
// Once closed, a device never changes state again (see changeState).
type deviceState uint32

// The go:generate directive below produces a String method with stringer;
// changeState and Close log these values with %s.
//go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState
const (
	deviceStateDown deviceState = iota // zero value; NewDevice starts here
	deviceStateUp
	deviceStateClosed
)

// deviceState returns device.state.state as a deviceState.
// It performs a single atomic load and takes no lock, so the result is only a
// snapshot; see the device.state.state docs for how to interpret it.
func (device *Device) deviceState() deviceState {
	return deviceState(atomic.LoadUint32(&device.state.state))
}

// isClosed reports whether the device is closed (or is closing).
// See device.state.state comments for how to interpret this value.
120 func (device *Device) isClosed() bool { 121 return device.deviceState() == deviceStateClosed 122 } 123 124 // isUp reports whether the device is up (or is attempting to come up). 125 // See device.state.state comments for how to interpret this value. 126 func (device *Device) isUp() bool { 127 return device.deviceState() == deviceStateUp 128 } 129 130 // Must hold device.peers.Lock() 131 func removePeerLocked(device *Device, peer *Peer, key NoisePublicKey) { 132 // stop routing and processing of packets 133 device.allowedips.RemoveByPeer(peer) 134 peer.Stop() 135 136 // remove from peer map 137 delete(device.peers.keyMap, key) 138 } 139 140 // changeState attempts to change the device state to match want. 141 func (device *Device) changeState(want deviceState) (err error) { 142 device.state.Lock() 143 defer device.state.Unlock() 144 old := device.deviceState() 145 if old == deviceStateClosed { 146 // once closed, always closed 147 device.log.Verbosef("Interface closed, ignored requested state %s", want) 148 return nil 149 } 150 switch want { 151 case old: 152 return nil 153 case deviceStateUp: 154 atomic.StoreUint32(&device.state.state, uint32(deviceStateUp)) 155 err = device.upLocked() 156 if err == nil { 157 break 158 } 159 fallthrough // up failed; bring the device all the way back down 160 case deviceStateDown: 161 atomic.StoreUint32(&device.state.state, uint32(deviceStateDown)) 162 errDown := device.downLocked() 163 if err == nil { 164 err = errDown 165 } 166 } 167 device.log.Verbosef("Interface state was %s, requested %s, now %s", old, want, device.deviceState()) 168 return 169 } 170 171 // upLocked attempts to bring the device up and reports whether it succeeded. 172 // The caller must hold device.state.mu and is responsible for updating device.state.state. 
func (device *Device) upLocked() error {
	// When called from changeState, state.state already reads deviceStateUp,
	// so BindUpdate (which only opens sockets while isUp reports true)
	// reopens the bind here.
	if err := device.BindUpdate(); err != nil {
		device.log.Errorf("Unable to update bind: %v", err)
		return err
	}

	// The IPC set operation waits for peers to be created before calling Start() on them,
	// so if there's a concurrent IPC set request happening, we should wait for it to complete.
	device.ipcMutex.Lock()
	defer device.ipcMutex.Unlock()

	// Start every known peer; peers with a persistent keepalive interval
	// configured send a keepalive right away.
	device.peers.RLock()
	for _, peer := range device.peers.keyMap {
		peer.Start()
		if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 {
			peer.SendKeepalive()
		}
	}
	device.peers.RUnlock()
	return nil
}

// downLocked attempts to bring the device down. It closes the bind (logging
// and returning any error) and stops every peer. Peers are stopped even when
// closing the bind fails.
// The caller must hold the device.state mutex and is responsible for updating device.state.state.
func (device *Device) downLocked() error {
	err := device.BindClose()
	if err != nil {
		device.log.Errorf("Bind close failed: %v", err)
	}

	device.peers.RLock()
	for _, peer := range device.peers.keyMap {
		peer.Stop()
	}
	device.peers.RUnlock()
	return err
}

// Up brings the device up. It is a no-op if the device is already up or has
// been closed. If bringing it up fails, the device is brought back down and
// the error from the failed attempt is returned.
func (device *Device) Up() error {
	return device.changeState(deviceStateUp)
}

// Down brings the device down, closing its bind and stopping all peers.
// It is a no-op if the device is already down or has been closed.
func (device *Device) Down() error {
	return device.changeState(deviceStateDown)
}

// IsUnderLoad reports whether the device should be treated as under load:
// either the handshake queue currently holds at least QueueHandshakeSize/8
// entries, or that was the case within the last UnderLoadAfterTime.
func (device *Device) IsUnderLoad() bool {
	// check if currently under load
	now := time.Now()
	underLoad := len(device.queue.handshake.c) >= QueueHandshakeSize/8
	if underLoad {
		// Remember the deadline until which we keep reporting "under load".
		atomic.StoreInt64(&device.rate.underLoadUntil, now.Add(UnderLoadAfterTime).UnixNano())
		return true
	}
	// check if recently under load
	return atomic.LoadInt64(&device.rate.underLoadUntil) > now.UnixNano()
}

// SetPrivateKey installs sk as the device's static private key.
// Setting the key that is already installed is a no-op. Otherwise:
//   - any peer whose remote static key equals the new public key is removed;
//   - the public key, obfuscator and cookie checker are updated;
//   - the static-static DH value is recomputed for every remaining peer;
//   - those peers' current keypairs are expired.
//
// It always returns nil.
func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
	// lock required resources

	device.staticIdentity.Lock()
	defer device.staticIdentity.Unlock()

	if sk.Equals(device.staticIdentity.privateKey) {
		return nil
	}

	device.peers.Lock()
	defer device.peers.Unlock()

	// Hold every peer's handshake read lock while the identity changes.
	lockedPeers := make([]*Peer, 0, len(device.peers.keyMap))
	for _, peer := range device.peers.keyMap {
		peer.handshake.mutex.RLock()
		lockedPeers = append(lockedPeers, peer)
	}

	// remove peers with matching public keys

	publicKey := sk.publicKey()
	for key, peer := range device.peers.keyMap {
		if peer.handshake.remoteStatic.Equals(publicKey) {
			// removePeerLocked runs without this peer's handshake read lock
			// (presumably because stopping the peer needs that mutex; TODO
			// confirm). The lock is re-taken so the unlock loop below stays
			// balanced. Deleting map entries during range is permitted in Go.
			peer.handshake.mutex.RUnlock()
			removePeerLocked(device, peer, key)
			peer.handshake.mutex.RLock()
		}
	}

	// update key material

	device.staticIdentity.privateKey = sk
	device.staticIdentity.publicKey = publicKey
	device.staticIdentity.obfuscator = wgNoiseCreateObfuscator(device.staticIdentity.publicKey)
	device.cookieChecker.Init(publicKey)

	// do static-static DH pre-computations

	expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
	for _, peer := range device.peers.keyMap {
		handshake := &peer.handshake
		handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
		expiredPeers = append(expiredPeers, peer)
	}

	// Release the handshake locks (including those of removed peers) before
	// expiring keypairs.
	for _, peer := range lockedPeers {
		peer.handshake.mutex.RUnlock()
	}
	for _, peer := range expiredPeers {
		peer.ExpireCurrentKeypairs()
	}

	return nil
}

// NewDevice creates a Device in the down state for tunDevice and bind,
// logging through logger. If the TUN MTU cannot be queried, DefaultMTU is
// used. It starts runtime.NumCPU() encryption, decryption and handshake
// workers plus the TUN reader and TUN event reader goroutines. Call Up to
// open the bind and Close to release everything.
func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
	device := new(Device)
	device.state.state = uint32(deviceStateDown)
	device.closed = make(chan struct{})
	device.log = logger
	device.net.bind = bind
	device.tun.device = tunDevice
	mtu, err := device.tun.device.MTU()
	if err != nil {
		device.log.Errorf("Trouble determining MTU, assuming default: %v", err)
		mtu = DefaultMTU
	}
	device.tun.mtu = int32(mtu)
	device.peers.keyMap = make(map[NoisePublicKey]*Peer)
	device.rate.limiter.Init()
	device.indexTable.Init()
	device.PopulatePools()

	// create queues

	device.queue.handshake = newHandshakeQueue()
	device.queue.encryption = newOutboundQueue()
	device.queue.decryption = newInboundQueue()

	// start workers

	cpus := runtime.NumCPU()
	// NOTE(review): the WaitGroup of a freshly allocated Device has a zero
	// counter, so this Wait returns immediately; it appears to be a no-op.
	device.state.stopping.Wait()
	device.queue.encryption.wg.Add(cpus) // One for each RoutineHandshake
	for i := 0; i < cpus; i++ {
		go device.RoutineEncryption(i + 1)
		go device.RoutineDecryption(i + 1)
		go device.RoutineHandshake(i + 1)
	}

	// WaitGroup counts are added before the goroutines start.
	device.state.stopping.Add(1)      // RoutineReadFromTUN
	device.queue.encryption.wg.Add(1) // RoutineReadFromTUN
	go device.RoutineReadFromTUN()
	go device.RoutineTUNEventReader()

	return device
}

// LookupPeer returns the peer with public key pk, or nil if there is none.
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
	device.peers.RLock()
	defer device.peers.RUnlock()

	return device.peers.keyMap[pk]
}

// RemovePeer stops and removes the peer with public key key, if present.
func (device *Device) RemovePeer(key NoisePublicKey) {
	device.peers.Lock()
	defer device.peers.Unlock()
	// stop peer and remove from routing

	peer, ok := device.peers.keyMap[key]
	if ok {
		removePeerLocked(device, peer, key)
	}
}

// RemoveAllPeers stops and removes every peer, then resets the peer map.
func (device *Device) RemoveAllPeers() {
	device.peers.Lock()
	defer device.peers.Unlock()

	for key, peer := range device.peers.keyMap {
		removePeerLocked(device, peer, key)
	}

	device.peers.keyMap = make(map[NoisePublicKey]*Peer)
}

// Close shuts the device down permanently; calling it again does nothing.
// It closes the TUN device, brings the device down, removes all peers,
// releases the device's references on the worker queues, waits for
// device.state.stopping, closes the rate limiter and finally closes the
// channel returned by Wait.
func (device *Device) Close() {
	device.state.Lock()
	defer device.state.Unlock()
	if device.isClosed() {
		return
	}
	atomic.StoreUint32(&device.state.state, uint32(deviceStateClosed))
	device.log.Verbosef("Device closing")

	// Best-effort: errors from closing the TUN device and from downLocked are
	// not returned (downLocked logs bind-close failures itself).
	device.tun.device.Close()
	device.downLocked()

	// Remove peers before closing queues,
	// because peers assume that queues are active.
	device.RemoveAllPeers()

	// We kept a reference to the encryption and decryption queues,
	// in case we started any new peers that might write to them.
	// No new peers are coming; we are done with these queues.
	// (The matching Add calls are presumably made when the queues are
	// created; verify in the queue constructors.)
	device.queue.encryption.wg.Done()
	device.queue.decryption.wg.Done()
	device.queue.handshake.wg.Done()
	device.state.stopping.Wait()

	device.rate.limiter.Close()

	device.log.Verbosef("Device closed")
	close(device.closed)
}

// Wait returns a channel that is closed once Close has finished.
func (device *Device) Wait() chan struct{} {
	return device.closed
}

// SendKeepalivesToPeersWithCurrentKeypair sends a keepalive to every peer
// whose current keypair exists and is no older than RejectAfterTime.
// It does nothing unless the device is up.
func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
	if !device.isUp() {
		return
	}

	device.peers.RLock()
	for _, peer := range device.peers.keyMap {
		peer.keypairs.RLock()
		sendKeepalive := peer.keypairs.current != nil && !peer.keypairs.current.created.Add(RejectAfterTime).Before(time.Now())
		peer.keypairs.RUnlock()
		if sendKeepalive {
			peer.SendKeepalive()
		}
	}
	device.peers.RUnlock()
}

// closeBindLocked closes the device's net.bind (if set), cancels
// net.netlinkCancel (if set) and then waits for device.net.stopping.
// The caller must hold the net mutex.
412 func closeBindLocked(device *Device) error { 413 var err error 414 netc := &device.net 415 if netc.netlinkCancel != nil { 416 netc.netlinkCancel.Cancel() 417 } 418 if netc.bind != nil { 419 err = netc.bind.Close() 420 } 421 netc.stopping.Wait() 422 return err 423 } 424 425 func (device *Device) Bind() conn.Bind { 426 device.net.Lock() 427 defer device.net.Unlock() 428 return device.net.bind 429 } 430 431 func (device *Device) BindSetMark(mark uint32) error { 432 device.net.Lock() 433 defer device.net.Unlock() 434 435 // check if modified 436 if device.net.fwmark == mark { 437 return nil 438 } 439 440 // update fwmark on existing bind 441 device.net.fwmark = mark 442 if device.isUp() && device.net.bind != nil { 443 if err := device.net.bind.SetMark(mark); err != nil { 444 return err 445 } 446 } 447 448 // clear cached source addresses 449 device.peers.RLock() 450 for _, peer := range device.peers.keyMap { 451 peer.Lock() 452 defer peer.Unlock() 453 if peer.endpoint != nil { 454 peer.endpoint.ClearSrc() 455 } 456 } 457 device.peers.RUnlock() 458 459 return nil 460 } 461 462 func (device *Device) BindUpdate() error { 463 device.net.Lock() 464 defer device.net.Unlock() 465 466 // close existing sockets 467 if err := closeBindLocked(device); err != nil { 468 return err 469 } 470 471 // open new sockets 472 if !device.isUp() { 473 return nil 474 } 475 476 // bind to new port 477 var err error 478 var recvFns []conn.ReceiveFunc 479 netc := &device.net 480 recvFns, netc.port, err = netc.bind.Open(netc.port) 481 if err != nil { 482 netc.port = 0 483 return err 484 } 485 netc.netlinkCancel, err = device.startRouteListener(netc.bind) 486 if err != nil { 487 netc.bind.Close() 488 netc.port = 0 489 return err 490 } 491 492 // set fwmark 493 if netc.fwmark != 0 { 494 err = netc.bind.SetMark(netc.fwmark) 495 if err != nil { 496 return err 497 } 498 } 499 500 // clear cached source addresses 501 device.peers.RLock() 502 for _, peer := range device.peers.keyMap { 503 
peer.Lock() 504 defer peer.Unlock() 505 if peer.endpoint != nil { 506 peer.endpoint.ClearSrc() 507 } 508 } 509 device.peers.RUnlock() 510 511 // start receiving routines 512 device.net.stopping.Add(len(recvFns)) 513 device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption 514 device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake 515 for _, fn := range recvFns { 516 go device.RoutineReceiveIncoming(fn) 517 } 518 519 device.log.Verbosef("UDP bind has been updated") 520 return nil 521 } 522 523 func (device *Device) BindClose() error { 524 device.net.Lock() 525 err := closeBindLocked(device) 526 device.net.Unlock() 527 return err 528 }