github.com/tailscale/wireguard-go@v0.0.20201119-0.20210522003738-46b531feb08a/device/device.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package device 7 8 import ( 9 "runtime" 10 "sync" 11 "sync/atomic" 12 "time" 13 14 "github.com/tailscale/wireguard-go/conn" 15 "github.com/tailscale/wireguard-go/ratelimiter" 16 "github.com/tailscale/wireguard-go/rwcancel" 17 "github.com/tailscale/wireguard-go/tun" 18 ) 19 20 type Device struct { 21 state struct { 22 // state holds the device's state. It is accessed atomically. 23 // Use the device.deviceState method to read it. 24 // device.deviceState does not acquire the mutex, so it captures only a snapshot. 25 // During state transitions, the state variable is updated before the device itself. 26 // The state is thus either the current state of the device or 27 // the intended future state of the device. 28 // For example, while executing a call to Up, state will be deviceStateUp. 29 // There is no guarantee that that intended future state of the device 30 // will become the actual state; Up can fail. 31 // The device can also change state multiple times between time of check and time of use. 32 // Unsynchronized uses of state must therefore be advisory/best-effort only. 33 state uint32 // actually a deviceState, but typed uint32 for convenience 34 // stopping blocks until all inputs to Device have been closed. 35 stopping sync.WaitGroup 36 // mu protects state changes. 
37 sync.Mutex 38 } 39 40 net struct { 41 stopping sync.WaitGroup 42 sync.RWMutex 43 bind conn.Bind // bind interface 44 netlinkCancel *rwcancel.RWCancel 45 port uint16 // listening port 46 fwmark uint32 // mark value (0 = disabled) 47 } 48 49 staticIdentity struct { 50 sync.RWMutex 51 privateKey NoisePrivateKey 52 publicKey NoisePublicKey 53 } 54 55 rate struct { 56 underLoadUntil int64 57 limiter ratelimiter.Ratelimiter 58 } 59 60 peers struct { 61 sync.RWMutex // protects keyMap 62 keyMap map[NoisePublicKey]*Peer 63 } 64 65 allowedips AllowedIPs 66 indexTable IndexTable 67 cookieChecker CookieChecker 68 69 pool struct { 70 messageBuffers *WaitPool 71 inboundElements *WaitPool 72 outboundElements *WaitPool 73 } 74 75 queue struct { 76 encryption *outboundQueue 77 decryption *inboundQueue 78 handshake *handshakeQueue 79 } 80 81 tun struct { 82 device tun.Device 83 mtu int32 84 } 85 86 ipcMutex sync.RWMutex 87 closed chan struct{} 88 log *Logger 89 } 90 91 // deviceState represents the state of a Device. 92 // There are three states: down, up, closed. 93 // Transitions: 94 // 95 // down -----+ 96 // ↑↓ ↓ 97 // up -> closed 98 // 99 type deviceState uint32 100 101 //go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState 102 const ( 103 deviceStateDown deviceState = iota 104 deviceStateUp 105 deviceStateClosed 106 ) 107 108 // deviceState returns device.state.state as a deviceState 109 // See those docs for how to interpret this value. 110 func (device *Device) deviceState() deviceState { 111 return deviceState(atomic.LoadUint32(&device.state.state)) 112 } 113 114 // isClosed reports whether the device is closed (or is closing). 115 // See device.state.state comments for how to interpret this value. 116 func (device *Device) isClosed() bool { 117 return device.deviceState() == deviceStateClosed 118 } 119 120 // isUp reports whether the device is up (or is attempting to come up). 
// The answer is only a snapshot; see the device.state.state comments for how
// to interpret it.
func (device *Device) isUp() bool {
	current := device.deviceState()
	return current == deviceStateUp
}

// removePeerLocked stops routing and packet processing for peer, then deletes
// it from the peer map under key.
// The caller must hold device.peers.Lock().
func removePeerLocked(device *Device, peer *Peer, key NoisePublicKey) {
	device.allowedips.RemoveByPeer(peer)
	peer.Stop()
	delete(device.peers.keyMap, key)
}

// changeState moves the device toward the requested state under the state
// mutex. A closed device ignores every request. If bringing the device up
// fails, it is taken all the way back down and the error from the up attempt
// is returned (any error from the unwind is discarded in its favor).
func (device *Device) changeState(want deviceState) (err error) {
	device.state.Lock()
	defer device.state.Unlock()

	old := device.deviceState()
	if old == deviceStateClosed {
		// A closed device never leaves the closed state.
		device.log.Verbosef("Interface closed, ignored requested state %s", want)
		return nil
	}

	switch {
	case want == old:
		return nil
	case want == deviceStateUp:
		// Publish the intended state before doing the work; see the
		// device.state.state comments.
		atomic.StoreUint32(&device.state.state, uint32(deviceStateUp))
		if err = device.upLocked(); err != nil {
			// Unwind: bring the device all the way back down. The upLocked
			// error takes precedence over anything downLocked reports.
			atomic.StoreUint32(&device.state.state, uint32(deviceStateDown))
			_ = device.downLocked()
		}
	case want == deviceStateDown:
		atomic.StoreUint32(&device.state.state, uint32(deviceStateDown))
		err = device.downLocked()
	}

	device.log.Verbosef("Interface state was %s, requested %s, now %s", old, want, device.deviceState())
	return err
}

// upLocked attempts to bring the device up, returning a non-nil error on failure.
// The caller must hold device.state's embedded mutex and is responsible for
// updating device.state.state.
169 func (device *Device) upLocked() error { 170 if err := device.BindUpdate(); err != nil { 171 device.log.Errorf("Unable to update bind: %v", err) 172 return err 173 } 174 175 device.peers.RLock() 176 for _, peer := range device.peers.keyMap { 177 peer.Start() 178 if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 { 179 peer.SendKeepalive() 180 } 181 } 182 device.peers.RUnlock() 183 return nil 184 } 185 186 // downLocked attempts to bring the device down. 187 // The caller must hold device.state.mu and is responsible for updating device.state.state. 188 func (device *Device) downLocked() error { 189 err := device.BindClose() 190 if err != nil { 191 device.log.Errorf("Bind close failed: %v", err) 192 } 193 194 device.peers.RLock() 195 for _, peer := range device.peers.keyMap { 196 peer.Stop() 197 } 198 device.peers.RUnlock() 199 return err 200 } 201 202 func (device *Device) Up() error { 203 return device.changeState(deviceStateUp) 204 } 205 206 func (device *Device) Down() error { 207 return device.changeState(deviceStateDown) 208 } 209 210 func (device *Device) IsUnderLoad() bool { 211 // check if currently under load 212 now := time.Now() 213 underLoad := len(device.queue.handshake.c) >= QueueHandshakeSize/8 214 if underLoad { 215 atomic.StoreInt64(&device.rate.underLoadUntil, now.Add(UnderLoadAfterTime).UnixNano()) 216 return true 217 } 218 // check if recently under load 219 return atomic.LoadInt64(&device.rate.underLoadUntil) > now.UnixNano() 220 } 221 222 func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { 223 // lock required resources 224 225 device.staticIdentity.Lock() 226 defer device.staticIdentity.Unlock() 227 228 if sk.Equals(device.staticIdentity.privateKey) { 229 return nil 230 } 231 232 device.peers.Lock() 233 defer device.peers.Unlock() 234 235 lockedPeers := make([]*Peer, 0, len(device.peers.keyMap)) 236 for _, peer := range device.peers.keyMap { 237 peer.handshake.mutex.RLock() 238 lockedPeers = append(lockedPeers, peer) 
239 } 240 241 // remove peers with matching public keys 242 243 publicKey := sk.publicKey() 244 for key, peer := range device.peers.keyMap { 245 if peer.handshake.remoteStatic.Equals(publicKey) { 246 peer.handshake.mutex.RUnlock() 247 removePeerLocked(device, peer, key) 248 peer.handshake.mutex.RLock() 249 } 250 } 251 252 // update key material 253 254 device.staticIdentity.privateKey = sk 255 device.staticIdentity.publicKey = publicKey 256 device.cookieChecker.Init(publicKey) 257 258 // do static-static DH pre-computations 259 260 expiredPeers := make([]*Peer, 0, len(device.peers.keyMap)) 261 for _, peer := range device.peers.keyMap { 262 handshake := &peer.handshake 263 handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic) 264 expiredPeers = append(expiredPeers, peer) 265 } 266 267 for _, peer := range lockedPeers { 268 peer.handshake.mutex.RUnlock() 269 } 270 for _, peer := range expiredPeers { 271 peer.ExpireCurrentKeypairs() 272 } 273 274 return nil 275 } 276 277 func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device { 278 device := new(Device) 279 device.state.state = uint32(deviceStateDown) 280 device.closed = make(chan struct{}) 281 device.log = logger 282 device.net.bind = bind 283 device.tun.device = tunDevice 284 mtu, err := device.tun.device.MTU() 285 if err != nil { 286 device.log.Errorf("Trouble determining MTU, assuming default: %v", err) 287 mtu = DefaultMTU 288 } 289 device.tun.mtu = int32(mtu) 290 device.peers.keyMap = make(map[NoisePublicKey]*Peer) 291 device.rate.limiter.Init() 292 device.indexTable.Init() 293 device.PopulatePools() 294 295 // create queues 296 297 device.queue.handshake = newHandshakeQueue() 298 device.queue.encryption = newOutboundQueue() 299 device.queue.decryption = newInboundQueue() 300 301 // start workers 302 303 cpus := runtime.NumCPU() 304 device.state.stopping.Wait() 305 device.queue.encryption.wg.Add(cpus) // One for each RoutineHandshake 306 
for i := 0; i < cpus; i++ { 307 go device.RoutineEncryption(i + 1) 308 go device.RoutineDecryption(i + 1) 309 go device.RoutineHandshake(i + 1) 310 } 311 312 device.state.stopping.Add(1) // RoutineReadFromTUN 313 device.queue.encryption.wg.Add(1) // RoutineReadFromTUN 314 go device.RoutineReadFromTUN() 315 go device.RoutineTUNEventReader() 316 317 return device 318 } 319 320 func (device *Device) LookupPeer(pk NoisePublicKey) *Peer { 321 device.peers.RLock() 322 defer device.peers.RUnlock() 323 324 return device.peers.keyMap[pk] 325 } 326 327 func (device *Device) RemovePeer(key NoisePublicKey) { 328 device.peers.Lock() 329 defer device.peers.Unlock() 330 // stop peer and remove from routing 331 332 peer, ok := device.peers.keyMap[key] 333 if ok { 334 removePeerLocked(device, peer, key) 335 } 336 } 337 338 func (device *Device) RemoveAllPeers() { 339 device.peers.Lock() 340 defer device.peers.Unlock() 341 342 for key, peer := range device.peers.keyMap { 343 removePeerLocked(device, peer, key) 344 } 345 346 device.peers.keyMap = make(map[NoisePublicKey]*Peer) 347 } 348 349 func (device *Device) Close() { 350 device.state.Lock() 351 defer device.state.Unlock() 352 if device.isClosed() { 353 return 354 } 355 atomic.StoreUint32(&device.state.state, uint32(deviceStateClosed)) 356 device.log.Verbosef("Device closing") 357 358 device.tun.device.Close() 359 device.downLocked() 360 361 // Remove peers before closing queues, 362 // because peers assume that queues are active. 363 device.RemoveAllPeers() 364 365 // We kept a reference to the encryption and decryption queues, 366 // in case we started any new peers that might write to them. 367 // No new peers are coming; we are done with these queues. 
368 device.queue.encryption.wg.Done() 369 device.queue.decryption.wg.Done() 370 device.queue.handshake.wg.Done() 371 device.state.stopping.Wait() 372 373 device.rate.limiter.Close() 374 375 device.log.Verbosef("Device closed") 376 close(device.closed) 377 } 378 379 func (device *Device) Wait() chan struct{} { 380 return device.closed 381 } 382 383 func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() { 384 if !device.isUp() { 385 return 386 } 387 388 device.peers.RLock() 389 for _, peer := range device.peers.keyMap { 390 peer.keypairs.RLock() 391 sendKeepalive := peer.keypairs.current != nil && !peer.keypairs.current.created.Add(RejectAfterTime).Before(time.Now()) 392 peer.keypairs.RUnlock() 393 if sendKeepalive { 394 peer.SendKeepalive() 395 } 396 } 397 device.peers.RUnlock() 398 } 399 400 // closeBindLocked closes the device's net.bind. 401 // The caller must hold the net mutex. 402 func closeBindLocked(device *Device) error { 403 var err error 404 netc := &device.net 405 if netc.netlinkCancel != nil { 406 netc.netlinkCancel.Cancel() 407 } 408 if netc.bind != nil { 409 err = netc.bind.Close() 410 } 411 netc.stopping.Wait() 412 return err 413 } 414 415 func (device *Device) Bind() conn.Bind { 416 device.net.Lock() 417 defer device.net.Unlock() 418 return device.net.bind 419 } 420 421 func (device *Device) BindSetMark(mark uint32) error { 422 device.net.Lock() 423 defer device.net.Unlock() 424 425 // check if modified 426 if device.net.fwmark == mark { 427 return nil 428 } 429 430 // update fwmark on existing bind 431 device.net.fwmark = mark 432 if device.isUp() && device.net.bind != nil { 433 if err := device.net.bind.SetMark(mark); err != nil { 434 return err 435 } 436 } 437 438 // clear cached source addresses 439 device.peers.RLock() 440 for _, peer := range device.peers.keyMap { 441 peer.Lock() 442 defer peer.Unlock() 443 if peer.endpoint != nil { 444 peer.endpoint.ClearSrc() 445 } 446 } 447 device.peers.RUnlock() 448 449 return nil 450 } 451 452 
func (device *Device) BindUpdate() error { 453 device.net.Lock() 454 defer device.net.Unlock() 455 456 // close existing sockets 457 if err := closeBindLocked(device); err != nil { 458 return err 459 } 460 461 // open new sockets 462 if !device.isUp() { 463 return nil 464 } 465 466 // bind to new port 467 var err error 468 var recvFns []conn.ReceiveFunc 469 netc := &device.net 470 recvFns, netc.port, err = netc.bind.Open(netc.port) 471 if err != nil { 472 netc.port = 0 473 return err 474 } 475 netc.netlinkCancel, err = device.startRouteListener(netc.bind) 476 if err != nil { 477 netc.bind.Close() 478 netc.port = 0 479 return err 480 } 481 482 // set fwmark 483 if netc.fwmark != 0 { 484 err = netc.bind.SetMark(netc.fwmark) 485 if err != nil { 486 return err 487 } 488 } 489 490 // clear cached source addresses 491 device.peers.RLock() 492 for _, peer := range device.peers.keyMap { 493 peer.Lock() 494 defer peer.Unlock() 495 if peer.endpoint != nil { 496 peer.endpoint.ClearSrc() 497 } 498 } 499 device.peers.RUnlock() 500 501 // start receiving routines 502 device.net.stopping.Add(len(recvFns)) 503 device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption 504 device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake 505 for _, fn := range recvFns { 506 go device.RoutineReceiveIncoming(fn) 507 } 508 509 device.log.Verbosef("UDP bind has been updated") 510 return nil 511 } 512 513 func (device *Device) BindClose() error { 514 device.net.Lock() 515 err := closeBindLocked(device) 516 device.net.Unlock() 517 return err 518 }