// github.com/liloew/wireguard-go@v0.0.0-20220224014633-9cd745e6f114/device/device.go

/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
 */

package device

import (
	"runtime"
	"sync"
	"sync/atomic"
	"time"

	"github.com/liloew/wireguard-go/conn"
	"github.com/liloew/wireguard-go/ratelimiter"
	"github.com/liloew/wireguard-go/rwcancel"
	"github.com/liloew/wireguard-go/tun"
)

// Device is a WireGuard interface. It connects a TUN device (tun) to a
// network bind (net) and owns the peer set, the static key pair, the worker
// queues and their buffer pools. Devices are created with NewDevice.
type Device struct {
	state struct {
		// state holds the device's state. It is accessed atomically.
		// Use the device.deviceState method to read it.
		// device.deviceState does not acquire the mutex, so it captures only a snapshot.
		// During state transitions, the state variable is updated before the device itself.
		// The state is thus either the current state of the device or
		// the intended future state of the device.
		// For example, while executing a call to Up, state will be deviceStateUp.
		// There is no guarantee that that intended future state of the device
		// will become the actual state; Up can fail.
		// The device can also change state multiple times between time of check and time of use.
		// Unsynchronized uses of state must therefore be advisory/best-effort only.
		state uint32 // actually a deviceState, but typed uint32 for convenience
		// stopping blocks until all inputs to Device have been closed.
		// NewDevice adds one for RoutineReadFromTUN; Close waits on it.
		stopping sync.WaitGroup
		// The embedded Mutex serializes state changes; changeState and Close
		// hold it for the whole transition.
		sync.Mutex
	}

	net struct {
		// stopping counts the receive goroutines started by BindUpdate;
		// closeBindLocked waits on it.
		stopping sync.WaitGroup
		// The embedded RWMutex is held by Bind, BindSetMark, BindUpdate and
		// BindClose while they access the fields below.
		sync.RWMutex
		bind          conn.Bind // bind interface
		netlinkCancel *rwcancel.RWCancel
		port          uint16 // listening port
		fwmark        uint32 // mark value (0 = disabled)
		brokenRoaming bool
	}

	// staticIdentity is the device's own key pair; SetPrivateKey replaces it
	// while holding the embedded lock.
	staticIdentity struct {
		sync.RWMutex
		privateKey NoisePrivateKey
		publicKey  NoisePublicKey
	}

	peers struct {
		sync.RWMutex // protects keyMap
		keyMap       map[NoisePublicKey]*Peer
	}

	// Keep this 8-byte aligned: underLoadUntil (a UnixNano timestamp) is
	// accessed with atomic.LoadInt64/StoreInt64 in IsUnderLoad, and 64-bit
	// atomics need 64-bit alignment on 32-bit platforms.
	rate struct {
		underLoadUntil int64
		limiter        ratelimiter.Ratelimiter
	}

	allowedips    AllowedIPs
	indexTable    IndexTable
	cookieChecker CookieChecker

	pool struct {
		messageBuffers   *WaitPool
		inboundElements  *WaitPool
		outboundElements *WaitPool
	}

	// Worker queues, created by NewDevice. Their wait groups count the
	// writers that may still send into them; Close releases the device's own
	// reference on each.
	queue struct {
		encryption *outboundQueue
		decryption *inboundQueue
		handshake  *handshakeQueue
	}

	tun struct {
		device tun.Device
		mtu    int32 // from device.MTU(), or DefaultMTU if that failed
	}

	// ipcMutex serializes IPC operations (upLocked takes it to wait out a
	// concurrent IPC set), closed is closed at the end of Close and is
	// returned by Wait, and log receives the device's log output.
	ipcMutex sync.RWMutex
	closed   chan struct{}
	log      *Logger
}

// deviceState represents the state of a Device.
// There are three states: down, up, closed.
// Transitions:
//
//	down -----+
//	  ↑↓      ↓
//	  up -> closed
type deviceState uint32

//go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState
const (
	deviceStateDown deviceState = iota
	deviceStateUp
	deviceStateClosed
)

// deviceState returns device.state.state as a deviceState.
// It performs an atomic load without taking device.state's mutex, so the
// result is only a snapshot; see device.state.state for how to interpret it.
func (device *Device) deviceState() deviceState {
	return deviceState(atomic.LoadUint32(&device.state.state))
}

// isClosed reports whether the device is closed (or is closing).
// See device.state.state comments for how to interpret this value.
118 func (device *Device) isClosed() bool { 119 return device.deviceState() == deviceStateClosed 120 } 121 122 // isUp reports whether the device is up (or is attempting to come up). 123 // See device.state.state comments for how to interpret this value. 124 func (device *Device) isUp() bool { 125 return device.deviceState() == deviceStateUp 126 } 127 128 // Must hold device.peers.Lock() 129 func removePeerLocked(device *Device, peer *Peer, key NoisePublicKey) { 130 // stop routing and processing of packets 131 device.allowedips.RemoveByPeer(peer) 132 peer.Stop() 133 134 // remove from peer map 135 delete(device.peers.keyMap, key) 136 } 137 138 // changeState attempts to change the device state to match want. 139 func (device *Device) changeState(want deviceState) (err error) { 140 device.state.Lock() 141 defer device.state.Unlock() 142 old := device.deviceState() 143 if old == deviceStateClosed { 144 // once closed, always closed 145 device.log.Verbosef("Interface closed, ignored requested state %s", want) 146 return nil 147 } 148 switch want { 149 case old: 150 return nil 151 case deviceStateUp: 152 atomic.StoreUint32(&device.state.state, uint32(deviceStateUp)) 153 err = device.upLocked() 154 if err == nil { 155 break 156 } 157 fallthrough // up failed; bring the device all the way back down 158 case deviceStateDown: 159 atomic.StoreUint32(&device.state.state, uint32(deviceStateDown)) 160 errDown := device.downLocked() 161 if err == nil { 162 err = errDown 163 } 164 } 165 device.log.Verbosef("Interface state was %s, requested %s, now %s", old, want, device.deviceState()) 166 return 167 } 168 169 // upLocked attempts to bring the device up and reports whether it succeeded. 170 // The caller must hold device.state.mu and is responsible for updating device.state.state. 
171 func (device *Device) upLocked() error { 172 if err := device.BindUpdate(); err != nil { 173 device.log.Errorf("Unable to update bind: %v", err) 174 return err 175 } 176 177 // The IPC set operation waits for peers to be created before calling Start() on them, 178 // so if there's a concurrent IPC set request happening, we should wait for it to complete. 179 device.ipcMutex.Lock() 180 defer device.ipcMutex.Unlock() 181 182 device.peers.RLock() 183 for _, peer := range device.peers.keyMap { 184 peer.Start() 185 if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 { 186 peer.SendKeepalive() 187 } 188 } 189 device.peers.RUnlock() 190 return nil 191 } 192 193 // downLocked attempts to bring the device down. 194 // The caller must hold device.state.mu and is responsible for updating device.state.state. 195 func (device *Device) downLocked() error { 196 err := device.BindClose() 197 if err != nil { 198 device.log.Errorf("Bind close failed: %v", err) 199 } 200 201 device.peers.RLock() 202 for _, peer := range device.peers.keyMap { 203 peer.Stop() 204 } 205 device.peers.RUnlock() 206 return err 207 } 208 209 func (device *Device) Up() error { 210 return device.changeState(deviceStateUp) 211 } 212 213 func (device *Device) Down() error { 214 return device.changeState(deviceStateDown) 215 } 216 217 func (device *Device) IsUnderLoad() bool { 218 // check if currently under load 219 now := time.Now() 220 underLoad := len(device.queue.handshake.c) >= QueueHandshakeSize/8 221 if underLoad { 222 atomic.StoreInt64(&device.rate.underLoadUntil, now.Add(UnderLoadAfterTime).UnixNano()) 223 return true 224 } 225 // check if recently under load 226 return atomic.LoadInt64(&device.rate.underLoadUntil) > now.UnixNano() 227 } 228 229 func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { 230 // lock required resources 231 232 device.staticIdentity.Lock() 233 defer device.staticIdentity.Unlock() 234 235 if sk.Equals(device.staticIdentity.privateKey) { 236 return nil 237 
} 238 239 device.peers.Lock() 240 defer device.peers.Unlock() 241 242 lockedPeers := make([]*Peer, 0, len(device.peers.keyMap)) 243 for _, peer := range device.peers.keyMap { 244 peer.handshake.mutex.RLock() 245 lockedPeers = append(lockedPeers, peer) 246 } 247 248 // remove peers with matching public keys 249 250 publicKey := sk.publicKey() 251 for key, peer := range device.peers.keyMap { 252 if peer.handshake.remoteStatic.Equals(publicKey) { 253 peer.handshake.mutex.RUnlock() 254 removePeerLocked(device, peer, key) 255 peer.handshake.mutex.RLock() 256 } 257 } 258 259 // update key material 260 261 device.staticIdentity.privateKey = sk 262 device.staticIdentity.publicKey = publicKey 263 device.cookieChecker.Init(publicKey) 264 265 // do static-static DH pre-computations 266 267 expiredPeers := make([]*Peer, 0, len(device.peers.keyMap)) 268 for _, peer := range device.peers.keyMap { 269 handshake := &peer.handshake 270 handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic) 271 expiredPeers = append(expiredPeers, peer) 272 } 273 274 for _, peer := range lockedPeers { 275 peer.handshake.mutex.RUnlock() 276 } 277 for _, peer := range expiredPeers { 278 peer.ExpireCurrentKeypairs() 279 } 280 281 return nil 282 } 283 284 func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device { 285 device := new(Device) 286 device.state.state = uint32(deviceStateDown) 287 device.closed = make(chan struct{}) 288 device.log = logger 289 device.net.bind = bind 290 device.tun.device = tunDevice 291 mtu, err := device.tun.device.MTU() 292 if err != nil { 293 device.log.Errorf("Trouble determining MTU, assuming default: %v", err) 294 mtu = DefaultMTU 295 } 296 device.tun.mtu = int32(mtu) 297 device.peers.keyMap = make(map[NoisePublicKey]*Peer) 298 device.rate.limiter.Init() 299 device.indexTable.Init() 300 device.PopulatePools() 301 302 // create queues 303 304 device.queue.handshake = newHandshakeQueue() 305 
device.queue.encryption = newOutboundQueue() 306 device.queue.decryption = newInboundQueue() 307 308 // start workers 309 310 cpus := runtime.NumCPU() 311 device.state.stopping.Wait() 312 device.queue.encryption.wg.Add(cpus) // One for each RoutineHandshake 313 for i := 0; i < cpus; i++ { 314 go device.RoutineEncryption(i + 1) 315 go device.RoutineDecryption(i + 1) 316 go device.RoutineHandshake(i + 1) 317 } 318 319 device.state.stopping.Add(1) // RoutineReadFromTUN 320 device.queue.encryption.wg.Add(1) // RoutineReadFromTUN 321 go device.RoutineReadFromTUN() 322 go device.RoutineTUNEventReader() 323 324 return device 325 } 326 327 func (device *Device) LookupPeer(pk NoisePublicKey) *Peer { 328 device.peers.RLock() 329 defer device.peers.RUnlock() 330 331 return device.peers.keyMap[pk] 332 } 333 334 func (device *Device) RemovePeer(key NoisePublicKey) { 335 device.peers.Lock() 336 defer device.peers.Unlock() 337 // stop peer and remove from routing 338 339 peer, ok := device.peers.keyMap[key] 340 if ok { 341 removePeerLocked(device, peer, key) 342 } 343 } 344 345 func (device *Device) RemoveAllPeers() { 346 device.peers.Lock() 347 defer device.peers.Unlock() 348 349 for key, peer := range device.peers.keyMap { 350 removePeerLocked(device, peer, key) 351 } 352 353 device.peers.keyMap = make(map[NoisePublicKey]*Peer) 354 } 355 356 func (device *Device) Close() { 357 device.state.Lock() 358 defer device.state.Unlock() 359 if device.isClosed() { 360 return 361 } 362 atomic.StoreUint32(&device.state.state, uint32(deviceStateClosed)) 363 device.log.Verbosef("Device closing") 364 365 device.tun.device.Close() 366 device.downLocked() 367 368 // Remove peers before closing queues, 369 // because peers assume that queues are active. 370 device.RemoveAllPeers() 371 372 // We kept a reference to the encryption and decryption queues, 373 // in case we started any new peers that might write to them. 374 // No new peers are coming; we are done with these queues. 
375 device.queue.encryption.wg.Done() 376 device.queue.decryption.wg.Done() 377 device.queue.handshake.wg.Done() 378 device.state.stopping.Wait() 379 380 device.rate.limiter.Close() 381 382 device.log.Verbosef("Device closed") 383 close(device.closed) 384 } 385 386 func (device *Device) Wait() chan struct{} { 387 return device.closed 388 } 389 390 func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() { 391 if !device.isUp() { 392 return 393 } 394 395 device.peers.RLock() 396 for _, peer := range device.peers.keyMap { 397 peer.keypairs.RLock() 398 sendKeepalive := peer.keypairs.current != nil && !peer.keypairs.current.created.Add(RejectAfterTime).Before(time.Now()) 399 peer.keypairs.RUnlock() 400 if sendKeepalive { 401 peer.SendKeepalive() 402 } 403 } 404 device.peers.RUnlock() 405 } 406 407 // closeBindLocked closes the device's net.bind. 408 // The caller must hold the net mutex. 409 func closeBindLocked(device *Device) error { 410 var err error 411 netc := &device.net 412 if netc.netlinkCancel != nil { 413 netc.netlinkCancel.Cancel() 414 } 415 if netc.bind != nil { 416 err = netc.bind.Close() 417 } 418 netc.stopping.Wait() 419 return err 420 } 421 422 func (device *Device) Bind() conn.Bind { 423 device.net.Lock() 424 defer device.net.Unlock() 425 return device.net.bind 426 } 427 428 func (device *Device) BindSetMark(mark uint32) error { 429 device.net.Lock() 430 defer device.net.Unlock() 431 432 // check if modified 433 if device.net.fwmark == mark { 434 return nil 435 } 436 437 // update fwmark on existing bind 438 device.net.fwmark = mark 439 if device.isUp() && device.net.bind != nil { 440 if err := device.net.bind.SetMark(mark); err != nil { 441 return err 442 } 443 } 444 445 // clear cached source addresses 446 device.peers.RLock() 447 for _, peer := range device.peers.keyMap { 448 peer.Lock() 449 defer peer.Unlock() 450 if peer.endpoint != nil { 451 peer.endpoint.ClearSrc() 452 } 453 } 454 device.peers.RUnlock() 455 456 return nil 457 } 458 459 
func (device *Device) BindUpdate() error { 460 device.net.Lock() 461 defer device.net.Unlock() 462 463 // close existing sockets 464 if err := closeBindLocked(device); err != nil { 465 return err 466 } 467 468 // open new sockets 469 if !device.isUp() { 470 return nil 471 } 472 473 // bind to new port 474 var err error 475 var recvFns []conn.ReceiveFunc 476 netc := &device.net 477 recvFns, netc.port, err = netc.bind.Open(netc.port) 478 if err != nil { 479 netc.port = 0 480 return err 481 } 482 netc.netlinkCancel, err = device.startRouteListener(netc.bind) 483 if err != nil { 484 netc.bind.Close() 485 netc.port = 0 486 return err 487 } 488 489 // set fwmark 490 if netc.fwmark != 0 { 491 err = netc.bind.SetMark(netc.fwmark) 492 if err != nil { 493 return err 494 } 495 } 496 497 // clear cached source addresses 498 device.peers.RLock() 499 for _, peer := range device.peers.keyMap { 500 peer.Lock() 501 defer peer.Unlock() 502 if peer.endpoint != nil { 503 peer.endpoint.ClearSrc() 504 } 505 } 506 device.peers.RUnlock() 507 508 // start receiving routines 509 device.net.stopping.Add(len(recvFns)) 510 device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption 511 device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake 512 for _, fn := range recvFns { 513 go device.RoutineReceiveIncoming(fn) 514 } 515 516 device.log.Verbosef("UDP bind has been updated") 517 return nil 518 } 519 520 func (device *Device) BindClose() error { 521 device.net.Lock() 522 err := closeBindLocked(device) 523 device.net.Unlock() 524 return err 525 }