// github.com/GFW-knocker/wireguard@v1.0.1/device/device.go

/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package device

import (
	"runtime"
	"sync"
	"sync/atomic"
	"time"

	"github.com/GFW-knocker/wireguard/conn"
	"github.com/GFW-knocker/wireguard/ratelimiter"
	"github.com/GFW-knocker/wireguard/rwcancel"
	"github.com/GFW-knocker/wireguard/tun"
)

// Device is a WireGuard interface instance. It ties together a TUN device
// (tun.device), a network bind (net.bind), the configured peers
// (peers.keyMap) and the worker queues that carry packets between them.
// Create one with NewDevice; bring it up and down with Up and Down, and shut
// it down permanently with Close.
type Device struct {
	state struct {
		// state holds the device's state. It is accessed atomically.
		// Use the device.deviceState method to read it.
		// device.deviceState does not acquire the mutex, so it captures only a snapshot.
		// During state transitions, the state variable is updated before the device itself.
		// The state is thus either the current state of the device or
		// the intended future state of the device.
		// For example, while executing a call to Up, state will be deviceStateUp.
		// There is no guarantee that the intended future state of the device
		// will become the actual state; Up can fail.
		// The device can also change state multiple times between time of check and time of use.
		// Unsynchronized uses of state must therefore be advisory/best-effort only.
		state atomic.Uint32 // actually a deviceState, but typed uint32 for convenience
		// stopping blocks until all inputs to Device have been closed.
		// NewDevice registers RoutineReadFromTUN on it; Close waits on it.
		stopping sync.WaitGroup
		// The embedded Mutex (taken as device.state.Lock) serializes state
		// changes: changeState and Close hold it for their whole duration.
		sync.Mutex
	}

	net struct {
		// stopping counts the receive goroutines started by BindUpdate;
		// closeBindLocked waits on it after closing the bind.
		stopping sync.WaitGroup
		// The embedded RWMutex guards the fields below.
		sync.RWMutex
		bind          conn.Bind          // bind interface
		netlinkCancel *rwcancel.RWCancel // handle returned by startRouteListener in BindUpdate; may be nil
		port          uint16             // listening port
		fwmark        uint32             // mark value (0 = disabled)
		brokenRoaming bool
	}

	// staticIdentity is the device's own key pair; SetPrivateKey replaces it
	// while holding the write lock.
	staticIdentity struct {
		sync.RWMutex
		privateKey NoisePrivateKey
		publicKey  NoisePublicKey
	}

	peers struct {
		sync.RWMutex // protects keyMap
		keyMap       map[NoisePublicKey]*Peer
	}

	rate struct {
		// underLoadUntil is a time.Time.UnixNano() deadline maintained by
		// IsUnderLoad.
		underLoadUntil atomic.Int64
		// limiter is initialized in NewDevice and closed in Close.
		limiter ratelimiter.Ratelimiter
	}

	allowedips    AllowedIPs
	indexTable    IndexTable
	cookieChecker CookieChecker // re-initialized with the public key by SetPrivateKey

	// pool holds the memory pools filled by PopulatePools (defined elsewhere)
	// during NewDevice.
	pool struct {
		inboundElementsContainer  *WaitPool
		outboundElementsContainer *WaitPool
		messageBuffers            *WaitPool
		inboundElements           *WaitPool
		outboundElements          *WaitPool
	}

	// queue holds the work queues created in NewDevice. The handshake
	// queue's length drives IsUnderLoad.
	queue struct {
		encryption *outboundQueue
		decryption *inboundQueue
		handshake  *handshakeQueue
	}

	tun struct {
		device tun.Device
		// mtu is set in NewDevice from device.MTU(), or DefaultMTU if that fails.
		mtu atomic.Int32
	}

	// ipcMutex is held by upLocked and Close so they do not interleave with an
	// IPC configuration operation (presumably taken by the IPC handlers
	// defined elsewhere — not visible here).
	ipcMutex sync.RWMutex
	// closed is closed at the very end of Close; Wait returns it.
	closed chan struct{}
	log    *Logger
}

// deviceState represents the state of a Device.
// There are three states: down, up, closed.
// Transitions:
//
//	down -----+
//	  ↑↓      ↓
//	  up -> closed
type deviceState uint32

//go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState
const (
	deviceStateDown deviceState = iota
	deviceStateUp
	deviceStateClosed
)

// deviceState returns device.state.state as a deviceState.
// It is a lock-free snapshot; see the device.state.state comments for how to
// interpret this value.
func (device *Device) deviceState() deviceState {
	return deviceState(device.state.state.Load())
}

// isClosed reports whether the device is closed (or is closing).
// The result is an unsynchronized snapshot; see the device.state.state
// comments for how to interpret this value.
118 func (device *Device) isClosed() bool { 119 return device.deviceState() == deviceStateClosed 120 } 121 122 // isUp reports whether the device is up (or is attempting to come up). 123 // See device.state.state comments for how to interpret this value. 124 func (device *Device) isUp() bool { 125 return device.deviceState() == deviceStateUp 126 } 127 128 // Must hold device.peers.Lock() 129 func removePeerLocked(device *Device, peer *Peer, key NoisePublicKey) { 130 // stop routing and processing of packets 131 device.allowedips.RemoveByPeer(peer) 132 peer.Stop() 133 134 // remove from peer map 135 delete(device.peers.keyMap, key) 136 } 137 138 // changeState attempts to change the device state to match want. 139 func (device *Device) changeState(want deviceState) (err error) { 140 device.state.Lock() 141 defer device.state.Unlock() 142 old := device.deviceState() 143 if old == deviceStateClosed { 144 // once closed, always closed 145 device.log.Verbosef("Interface closed, ignored requested state %s", want) 146 return nil 147 } 148 switch want { 149 case old: 150 return nil 151 case deviceStateUp: 152 device.state.state.Store(uint32(deviceStateUp)) 153 err = device.upLocked() 154 if err == nil { 155 break 156 } 157 fallthrough // up failed; bring the device all the way back down 158 case deviceStateDown: 159 device.state.state.Store(uint32(deviceStateDown)) 160 errDown := device.downLocked() 161 if err == nil { 162 err = errDown 163 } 164 } 165 device.log.Verbosef("Interface state was %s, requested %s, now %s", old, want, device.deviceState()) 166 return 167 } 168 169 // upLocked attempts to bring the device up and reports whether it succeeded. 170 // The caller must hold device.state.mu and is responsible for updating device.state.state. 
func (device *Device) upLocked() error {
	// (Re)open the bind before starting any peers.
	if err := device.BindUpdate(); err != nil {
		device.log.Errorf("Unable to update bind: %v", err)
		return err
	}

	// The IPC set operation waits for peers to be created before calling Start() on them,
	// so if there's a concurrent IPC set request happening, we should wait for it to complete.
	device.ipcMutex.Lock()
	defer device.ipcMutex.Unlock()

	// Start every peer; peers configured with a persistent keepalive get one
	// sent immediately.
	device.peers.RLock()
	for _, peer := range device.peers.keyMap {
		peer.Start()
		if peer.persistentKeepaliveInterval.Load() > 0 {
			peer.SendKeepalive()
		}
	}
	device.peers.RUnlock()
	return nil
}

// downLocked attempts to bring the device down: it closes the bind (logging
// any failure) and stops every peer, then returns the bind close error.
// Peers are stopped even if closing the bind failed.
// The caller must hold the device.state mutex and is responsible for updating
// device.state.state.
func (device *Device) downLocked() error {
	err := device.BindClose()
	if err != nil {
		device.log.Errorf("Bind close failed: %v", err)
	}

	device.peers.RLock()
	for _, peer := range device.peers.keyMap {
		peer.Stop()
	}
	device.peers.RUnlock()
	return err
}

// Up brings the device up: it opens the bind and starts all peers. If that
// fails the device is brought back down and the error is returned. Up is a
// no-op (returning nil) on a closed device.
func (device *Device) Up() error {
	return device.changeState(deviceStateUp)
}

// Down closes the device's bind and stops all peers. It is a no-op (returning
// nil) on a closed device.
func (device *Device) Down() error {
	return device.changeState(deviceStateDown)
}

// IsUnderLoad reports whether the device is, or was recently, under handshake
// load. It counts as under load when the handshake queue
// (len(device.queue.handshake.c)) holds at least QueueHandshakeSize/8
// entries; in that case the under-load deadline is also pushed out to
// now+UnderLoadAfterTime. Otherwise it reports whether that deadline is still
// in the future.
func (device *Device) IsUnderLoad() bool {
	// check if currently under load
	now := time.Now()
	underLoad := len(device.queue.handshake.c) >= QueueHandshakeSize/8
	if underLoad {
		device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime).UnixNano())
		return true
	}
	// check if recently under load
	return device.rate.underLoadUntil.Load() > now.UnixNano()
}

// SetPrivateKey installs sk as the device's static private key.
//
// If sk equals the current key nothing happens. Otherwise it takes these locks
// in order: staticIdentity (write), peers (write), and every peer's handshake
// mutex (read). While holding them it does the following:
//   - removes any peer whose remote static key equals sk's public key;
//   - stores the new key pair and re-initializes the cookie checker with the
//     new public key;
//   - recomputes the static-static shared secret for each remaining peer.
//
// After dropping the handshake locks it expires the current keypairs of those
// peers. It currently always returns nil.
func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
	// lock required resources

	device.staticIdentity.Lock()
	defer device.staticIdentity.Unlock()

	if sk.Equals(device.staticIdentity.privateKey) {
		return nil
	}

	device.peers.Lock()
	defer device.peers.Unlock()

	lockedPeers := make([]*Peer, 0, len(device.peers.keyMap))
	for _, peer := range device.peers.keyMap {
		peer.handshake.mutex.RLock()
		lockedPeers = append(lockedPeers, peer)
	}

	// remove peers with matching public keys

	publicKey := sk.publicKey()
	for key, peer := range device.peers.keyMap {
		if peer.handshake.remoteStatic.Equals(publicKey) {
			// Release this peer's read lock while it is stopped and removed,
			// then re-take it: the peer is still in lockedPeers, so the unlock
			// loop below must find it locked.
			peer.handshake.mutex.RUnlock()
			removePeerLocked(device, peer, key)
			peer.handshake.mutex.RLock()
		}
	}

	// update key material

	device.staticIdentity.privateKey = sk
	device.staticIdentity.publicKey = publicKey
	device.cookieChecker.Init(publicKey)

	// do static-static DH pre-computations

	expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
	for _, peer := range device.peers.keyMap {
		handshake := &peer.handshake
		// NOTE(review): this field is written while holding only the handshake
		// *read* lock. It is presumably safe because readers also take
		// staticIdentity's read lock, which is held exclusively here — confirm.
		// The sharedSecret error is discarded; presumably it signals a
		// degenerate remote key that the handshake code rejects later — confirm.
		handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
		expiredPeers = append(expiredPeers, peer)
	}

	for _, peer := range lockedPeers {
		peer.handshake.mutex.RUnlock()
	}
	// Keypairs derived from the old identity are expired only after the
	// handshake locks are released.
	for _, peer := range expiredPeers {
		peer.ExpireCurrentKeypairs()
	}

	return nil
}

// NewDevice creates a Device for tunDevice and bind, logging through logger.
// The device starts in the down state; call Up to open the bind and start
// peers. The MTU is read from tunDevice. If that fails, an error is logged and
// DefaultMTU is used. For each CPU it starts one encryption, one decryption
// and one handshake worker, plus RoutineReadFromTUN and RoutineTUNEventReader.
func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
	device := new(Device)
	device.state.state.Store(uint32(deviceStateDown))
	device.closed = make(chan struct{})
	device.log = logger
	device.net.bind = bind
	device.tun.device = tunDevice
	mtu, err := device.tun.device.MTU()
	if err != nil {
		device.log.Errorf("Trouble determining MTU, assuming default: %v", err)
		mtu = DefaultMTU
	}
	device.tun.mtu.Store(int32(mtu))
	device.peers.keyMap = make(map[NoisePublicKey]*Peer)
	device.rate.limiter.Init()
	device.indexTable.Init()

	device.PopulatePools()

	// create queues

	device.queue.handshake = newHandshakeQueue()
	device.queue.encryption = newOutboundQueue()
	device.queue.decryption = newInboundQueue()

	// start workers

	cpus := runtime.NumCPU()
	// NOTE(review): nothing has been added to state.stopping yet on this
	// freshly allocated Device, so this Wait appears to return immediately.
	device.state.stopping.Wait()
	device.queue.encryption.wg.Add(cpus) // One for each RoutineHandshake
	for i := 0; i < cpus; i++ {
		go device.RoutineEncryption(i + 1)
		go device.RoutineDecryption(i + 1)
		go device.RoutineHandshake(i + 1)
	}

	device.state.stopping.Add(1)      // RoutineReadFromTUN
	device.queue.encryption.wg.Add(1) // RoutineReadFromTUN
	go device.RoutineReadFromTUN()
	go device.RoutineTUNEventReader()

	return device
}

// BatchSize returns the BatchSize for the device as a whole which is the max of
// the bind batch size and the tun batch size. The batch size reported by device
// is the size used to construct memory pools, and is the allowed batch size for
// the lifetime of the device.
func (device *Device) BatchSize() int {
	size := device.net.bind.BatchSize()
	dSize := device.tun.device.BatchSize()
	if size < dSize {
		size = dSize
	}
	return size
}

// LookupPeer returns the peer whose public key is pk, or nil if there is none.
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
	device.peers.RLock()
	defer device.peers.RUnlock()

	return device.peers.keyMap[pk]
}

// RemovePeer stops and removes the peer whose public key is key, if present.
func (device *Device) RemovePeer(key NoisePublicKey) {
	device.peers.Lock()
	defer device.peers.Unlock()
	// stop peer and remove from routing

	peer, ok := device.peers.keyMap[key]
	if ok {
		removePeerLocked(device, peer, key)
	}
}

// RemoveAllPeers stops and removes every peer, then replaces the peer map with
// a fresh empty one.
func (device *Device) RemoveAllPeers() {
	device.peers.Lock()
	defer device.peers.Unlock()

	for key, peer := range device.peers.keyMap {
		removePeerLocked(device, peer, key)
	}

	device.peers.keyMap = make(map[NoisePublicKey]*Peer)
}

// Close shuts the device down permanently. Calling it on an already closed
// device returns without doing anything. It holds the device.state mutex and
// ipcMutex for the whole shutdown, which runs in this order:
//   - mark the state closed;
//   - close the TUN device;
//   - bring the bind and peers down, then remove all peers;
//   - drop the device's own references on the three queues and wait on
//     state.stopping;
//   - close the rate limiter and the channel returned by Wait.
func (device *Device) Close() {
	device.state.Lock()
	defer device.state.Unlock()
	device.ipcMutex.Lock()
	defer device.ipcMutex.Unlock()
	if device.isClosed() {
		return
	}
	device.state.state.Store(uint32(deviceStateClosed))
	device.log.Verbosef("Device closing")

	// Both results below are discarded: Close has no error return, and
	// downLocked already logs bind-close failures. NOTE(review): the TUN Close
	// error is not reported anywhere.
	device.tun.device.Close()
	device.downLocked()

	// Remove peers before closing queues,
	// because peers assume that queues are active.
	device.RemoveAllPeers()

	// We kept a reference to the encryption and decryption queues,
	// in case we started any new peers that might write to them.
	// No new peers are coming; we are done with these queues.
	device.queue.encryption.wg.Done()
	device.queue.decryption.wg.Done()
	device.queue.handshake.wg.Done()
	device.state.stopping.Wait()

	device.rate.limiter.Close()

	device.log.Verbosef("Device closed")
	close(device.closed)
}

// Wait returns a channel that is closed once Close has finished shutting the
// device down.
func (device *Device) Wait() chan struct{} {
	return device.closed
}

// SendKeepalivesToPeersWithCurrentKeypair sends a keepalive to every peer that
// has a current keypair created less than RejectAfterTime ago. It does nothing
// unless the device is up (an advisory, unsynchronized check).
func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
	if !device.isUp() {
		return
	}

	device.peers.RLock()
	for _, peer := range device.peers.keyMap {
		peer.keypairs.RLock()
		sendKeepalive := peer.keypairs.current != nil && !peer.keypairs.current.created.Add(RejectAfterTime).Before(time.Now())
		peer.keypairs.RUnlock()
		if sendKeepalive {
			peer.SendKeepalive()
		}
	}
	device.peers.RUnlock()
}

// closeBindLocked closes the device's net.bind. It first cancels the route
// listener handle if one is set, then closes the bind, then waits for
// net.stopping, and returns the bind's Close error.
// The caller must hold the net mutex.
425 func closeBindLocked(device *Device) error { 426 var err error 427 netc := &device.net 428 if netc.netlinkCancel != nil { 429 netc.netlinkCancel.Cancel() 430 } 431 if netc.bind != nil { 432 err = netc.bind.Close() 433 } 434 netc.stopping.Wait() 435 return err 436 } 437 438 func (device *Device) Bind() conn.Bind { 439 device.net.Lock() 440 defer device.net.Unlock() 441 return device.net.bind 442 } 443 444 func (device *Device) BindSetMark(mark uint32) error { 445 device.net.Lock() 446 defer device.net.Unlock() 447 448 // check if modified 449 if device.net.fwmark == mark { 450 return nil 451 } 452 453 // update fwmark on existing bind 454 device.net.fwmark = mark 455 if device.isUp() && device.net.bind != nil { 456 if err := device.net.bind.SetMark(mark); err != nil { 457 return err 458 } 459 } 460 461 // clear cached source addresses 462 device.peers.RLock() 463 for _, peer := range device.peers.keyMap { 464 peer.markEndpointSrcForClearing() 465 } 466 device.peers.RUnlock() 467 468 return nil 469 } 470 471 func (device *Device) BindUpdate() error { 472 device.net.Lock() 473 defer device.net.Unlock() 474 475 // close existing sockets 476 if err := closeBindLocked(device); err != nil { 477 return err 478 } 479 480 // open new sockets 481 if !device.isUp() { 482 return nil 483 } 484 485 // bind to new port 486 var err error 487 var recvFns []conn.ReceiveFunc 488 netc := &device.net 489 490 recvFns, netc.port, err = netc.bind.Open(netc.port) 491 if err != nil { 492 netc.port = 0 493 return err 494 } 495 496 netc.netlinkCancel, err = device.startRouteListener(netc.bind) 497 if err != nil { 498 netc.bind.Close() 499 netc.port = 0 500 return err 501 } 502 503 // set fwmark 504 if netc.fwmark != 0 { 505 err = netc.bind.SetMark(netc.fwmark) 506 if err != nil { 507 return err 508 } 509 } 510 511 // clear cached source addresses 512 device.peers.RLock() 513 for _, peer := range device.peers.keyMap { 514 peer.markEndpointSrcForClearing() 515 } 516 device.peers.RUnlock() 
517 518 // start receiving routines 519 device.net.stopping.Add(len(recvFns)) 520 device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption 521 device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake 522 batchSize := netc.bind.BatchSize() 523 for _, fn := range recvFns { 524 go device.RoutineReceiveIncoming(batchSize, fn) 525 } 526 527 device.log.Verbosef("UDP bind has been updated") 528 return nil 529 } 530 531 func (device *Device) BindClose() error { 532 device.net.Lock() 533 err := closeBindLocked(device) 534 device.net.Unlock() 535 return err 536 }