github.com/sagernet/wireguard-go@v0.0.0-20231215174105-89dec3b2f3e8/device/device.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package device 7 8 import ( 9 "runtime" 10 "sync" 11 "time" 12 13 "github.com/sagernet/sing/common/atomic" 14 "github.com/sagernet/wireguard-go/conn" 15 "github.com/sagernet/wireguard-go/ratelimiter" 16 "github.com/sagernet/wireguard-go/rwcancel" 17 "github.com/sagernet/wireguard-go/tun" 18 ) 19 20 type Device struct { 21 state struct { 22 // state holds the device's state. It is accessed atomically. 23 // Use the device.deviceState method to read it. 24 // device.deviceState does not acquire the mutex, so it captures only a snapshot. 25 // During state transitions, the state variable is updated before the device itself. 26 // The state is thus either the current state of the device or 27 // the intended future state of the device. 28 // For example, while executing a call to Up, state will be deviceStateUp. 29 // There is no guarantee that that intended future state of the device 30 // will become the actual state; Up can fail. 31 // The device can also change state multiple times between time of check and time of use. 32 // Unsynchronized uses of state must therefore be advisory/best-effort only. 33 state atomic.Uint32 // actually a deviceState, but typed uint32 for convenience 34 // stopping blocks until all inputs to Device have been closed. 35 stopping sync.WaitGroup 36 // mu protects state changes. 37 sync.Mutex 38 } 39 40 net struct { 41 stopping sync.WaitGroup 42 sync.RWMutex 43 bind conn.Bind // bind interface 44 netlinkCancel *rwcancel.RWCancel 45 port uint16 // listening port 46 fwmark uint32 // mark value (0 = disabled) 47 brokenRoaming bool 48 } 49 50 staticIdentity struct { 51 sync.RWMutex 52 privateKey NoisePrivateKey 53 publicKey NoisePublicKey 54 } 55 56 peers struct { 57 sync.RWMutex // protects keyMap 58 keyMap map[NoisePublicKey]*Peer 59 } 60 61 rate struct { 62 underLoadUntil atomic.Int64 63 limiter ratelimiter.Ratelimiter 64 } 65 66 allowedips AllowedIPs 67 indexTable IndexTable 68 cookieChecker CookieChecker 69 70 pool struct { 71 inboundElementsContainer *WaitPool 72 outboundElementsContainer *WaitPool 73 messageBuffers *WaitPool 74 inboundElements *WaitPool 75 outboundElements *WaitPool 76 } 77 78 queue struct { 79 encryption *outboundQueue 80 decryption *inboundQueue 81 handshake *handshakeQueue 82 } 83 84 tun struct { 85 device tun.Device 86 mtu atomic.Int32 87 } 88 89 ipcMutex sync.RWMutex 90 closed chan struct{} 91 log *Logger 92 } 93 94 // deviceState represents the state of a Device. 95 // There are three states: down, up, closed. 96 // Transitions: 97 // 98 // down -----+ 99 // ↑↓ ↓ 100 // up -> closed 101 type deviceState uint32 102 103 //go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState 104 const ( 105 deviceStateDown deviceState = iota 106 deviceStateUp 107 deviceStateClosed 108 ) 109 110 // deviceState returns device.state.state as a deviceState 111 // See those docs for how to interpret this value. 112 func (device *Device) deviceState() deviceState { 113 return deviceState(device.state.state.Load()) 114 } 115 116 // isClosed reports whether the device is closed (or is closing). 117 // See device.state.state comments for how to interpret this value. 118 func (device *Device) isClosed() bool { 119 return device.deviceState() == deviceStateClosed 120 } 121 122 // isUp reports whether the device is up (or is attempting to come up). 123 // See device.state.state comments for how to interpret this value. 124 func (device *Device) isUp() bool { 125 return device.deviceState() == deviceStateUp 126 } 127 128 // Must hold device.peers.Lock() 129 func removePeerLocked(device *Device, peer *Peer, key NoisePublicKey) { 130 // stop routing and processing of packets 131 device.allowedips.RemoveByPeer(peer) 132 peer.Stop() 133 134 // remove from peer map 135 delete(device.peers.keyMap, key) 136 } 137 138 // changeState attempts to change the device state to match want. 139 func (device *Device) changeState(want deviceState) (err error) { 140 device.state.Lock() 141 defer device.state.Unlock() 142 old := device.deviceState() 143 if old == deviceStateClosed { 144 // once closed, always closed 145 device.log.Verbosef("Interface closed, ignored requested state %s", want) 146 return nil 147 } 148 switch want { 149 case old: 150 return nil 151 case deviceStateUp: 152 device.state.state.Store(uint32(deviceStateUp)) 153 err = device.upLocked() 154 if err == nil { 155 break 156 } 157 fallthrough // up failed; bring the device all the way back down 158 case deviceStateDown: 159 device.state.state.Store(uint32(deviceStateDown)) 160 errDown := device.downLocked() 161 if err == nil { 162 err = errDown 163 } 164 } 165 device.log.Verbosef("Interface state was %s, requested %s, now %s", old, want, device.deviceState()) 166 return 167 } 168 169 // upLocked attempts to bring the device up and reports whether it succeeded. 170 // The caller must hold device.state.mu and is responsible for updating device.state.state. 171 func (device *Device) upLocked() error { 172 if err := device.BindUpdate(); err != nil { 173 device.log.Errorf("Unable to update bind: %v", err) 174 return err 175 } 176 177 // The IPC set operation waits for peers to be created before calling Start() on them, 178 // so if there's a concurrent IPC set request happening, we should wait for it to complete. 179 device.ipcMutex.Lock() 180 defer device.ipcMutex.Unlock() 181 182 device.peers.RLock() 183 for _, peer := range device.peers.keyMap { 184 peer.Start() 185 if peer.persistentKeepaliveInterval.Load() > 0 { 186 peer.SendKeepalive() 187 } 188 } 189 device.peers.RUnlock() 190 return nil 191 } 192 193 // downLocked attempts to bring the device down. 194 // The caller must hold device.state.mu and is responsible for updating device.state.state. 195 func (device *Device) downLocked() error { 196 err := device.BindClose() 197 if err != nil { 198 device.log.Errorf("Bind close failed: %v", err) 199 } 200 201 device.peers.RLock() 202 for _, peer := range device.peers.keyMap { 203 peer.Stop() 204 } 205 device.peers.RUnlock() 206 return err 207 } 208 209 func (device *Device) Up() error { 210 return device.changeState(deviceStateUp) 211 } 212 213 func (device *Device) Down() error { 214 return device.changeState(deviceStateDown) 215 } 216 217 func (device *Device) IsUnderLoad() bool { 218 // check if currently under load 219 now := time.Now() 220 underLoad := len(device.queue.handshake.c) >= QueueHandshakeSize/8 221 if underLoad { 222 device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime).UnixNano()) 223 return true 224 } 225 // check if recently under load 226 return device.rate.underLoadUntil.Load() > now.UnixNano() 227 } 228 229 func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { 230 // lock required resources 231 232 device.staticIdentity.Lock() 233 defer device.staticIdentity.Unlock() 234 235 if sk.Equals(device.staticIdentity.privateKey) { 236 return nil 237 } 238 239 device.peers.Lock() 240 defer device.peers.Unlock() 241 242 lockedPeers := make([]*Peer, 0, len(device.peers.keyMap)) 243 for _, peer := range device.peers.keyMap { 244 peer.handshake.mutex.RLock() 245 lockedPeers = append(lockedPeers, peer) 246 } 247 248 // remove peers with matching public keys 249 250 publicKey := sk.publicKey() 251 for key, peer := range device.peers.keyMap { 252 if peer.handshake.remoteStatic.Equals(publicKey) { 253 peer.handshake.mutex.RUnlock() 254 removePeerLocked(device, peer, key) 255 peer.handshake.mutex.RLock() 256 } 257 } 258 259 // update key material 260 261 device.staticIdentity.privateKey = sk 262 device.staticIdentity.publicKey = publicKey 263 device.cookieChecker.Init(publicKey) 264 265 // do static-static DH pre-computations 266 267 expiredPeers := make([]*Peer, 0, len(device.peers.keyMap)) 268 for _, peer := range device.peers.keyMap { 269 handshake := &peer.handshake 270 handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic) 271 expiredPeers = append(expiredPeers, peer) 272 } 273 274 for _, peer := range lockedPeers { 275 peer.handshake.mutex.RUnlock() 276 } 277 for _, peer := range expiredPeers { 278 peer.ExpireCurrentKeypairs() 279 } 280 281 return nil 282 } 283 284 func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger, workers int) *Device { 285 device := new(Device) 286 device.state.state.Store(uint32(deviceStateDown)) 287 device.closed = make(chan struct{}) 288 device.log = logger 289 device.net.bind = bind 290 device.tun.device = tunDevice 291 mtu, err := device.tun.device.MTU() 292 if err != nil { 293 device.log.Errorf("Trouble determining MTU, assuming default: %v", err) 294 mtu = DefaultMTU 295 } 296 device.tun.mtu.Store(int32(mtu)) 297 device.peers.keyMap = make(map[NoisePublicKey]*Peer) 298 device.rate.limiter.Init() 299 device.indexTable.Init() 300 301 device.PopulatePools() 302 303 // create queues 304 305 device.queue.handshake = newHandshakeQueue() 306 device.queue.encryption = newOutboundQueue() 307 device.queue.decryption = newInboundQueue() 308 309 // start workers 310 311 if workers == 0 { 312 workers = runtime.NumCPU() 313 } 314 device.state.stopping.Wait() 315 device.queue.encryption.wg.Add(workers) // One for each RoutineHandshake 316 for i := 0; i < workers; i++ { 317 go device.RoutineEncryption(i + 1) 318 go device.RoutineDecryption(i + 1) 319 go device.RoutineHandshake(i + 1) 320 } 321 322 device.state.stopping.Add(1) // RoutineReadFromTUN 323 device.queue.encryption.wg.Add(1) // RoutineReadFromTUN 324 go device.RoutineReadFromTUN() 325 go device.RoutineTUNEventReader() 326 327 return device 328 } 329 330 // BatchSize returns the BatchSize for the device as a whole which is the max of 331 // the bind batch size and the tun batch size. The batch size reported by device 332 // is the size used to construct memory pools, and is the allowed batch size for 333 // the lifetime of the device. 334 func (device *Device) BatchSize() int { 335 size := device.net.bind.BatchSize() 336 dSize := device.tun.device.BatchSize() 337 if size < dSize { 338 size = dSize 339 } 340 return size 341 } 342 343 func (device *Device) LookupPeer(pk NoisePublicKey) *Peer { 344 device.peers.RLock() 345 defer device.peers.RUnlock() 346 347 return device.peers.keyMap[pk] 348 } 349 350 func (device *Device) RemovePeer(key NoisePublicKey) { 351 device.peers.Lock() 352 defer device.peers.Unlock() 353 // stop peer and remove from routing 354 355 peer, ok := device.peers.keyMap[key] 356 if ok { 357 removePeerLocked(device, peer, key) 358 } 359 } 360 361 func (device *Device) RemoveAllPeers() { 362 device.peers.Lock() 363 defer device.peers.Unlock() 364 365 for key, peer := range device.peers.keyMap { 366 removePeerLocked(device, peer, key) 367 } 368 369 device.peers.keyMap = make(map[NoisePublicKey]*Peer) 370 } 371 372 func (device *Device) Close() { 373 device.ipcMutex.Lock() 374 defer device.ipcMutex.Unlock() 375 device.state.Lock() 376 defer device.state.Unlock() 377 if device.isClosed() { 378 return 379 } 380 device.state.state.Store(uint32(deviceStateClosed)) 381 device.log.Verbosef("Device closing") 382 383 device.tun.device.Close() 384 device.downLocked() 385 386 // Remove peers before closing queues, 387 // because peers assume that queues are active. 388 device.RemoveAllPeers() 389 390 // We kept a reference to the encryption and decryption queues, 391 // in case we started any new peers that might write to them. 392 // No new peers are coming; we are done with these queues. 393 device.queue.encryption.wg.Done() 394 device.queue.decryption.wg.Done() 395 device.queue.handshake.wg.Done() 396 device.state.stopping.Wait() 397 398 device.rate.limiter.Close() 399 400 device.log.Verbosef("Device closed") 401 close(device.closed) 402 } 403 404 func (device *Device) Wait() chan struct{} { 405 return device.closed 406 } 407 408 func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() { 409 if !device.isUp() { 410 return 411 } 412 413 device.peers.RLock() 414 for _, peer := range device.peers.keyMap { 415 peer.keypairs.RLock() 416 sendKeepalive := peer.keypairs.current != nil && !peer.keypairs.current.created.Add(RejectAfterTime).Before(time.Now()) 417 peer.keypairs.RUnlock() 418 if sendKeepalive { 419 peer.SendKeepalive() 420 } 421 } 422 device.peers.RUnlock() 423 } 424 425 // closeBindLocked closes the device's net.bind. 426 // The caller must hold the net mutex. 427 func closeBindLocked(device *Device) error { 428 var err error 429 netc := &device.net 430 if netc.netlinkCancel != nil { 431 netc.netlinkCancel.Cancel() 432 } 433 if netc.bind != nil { 434 err = netc.bind.Close() 435 } 436 netc.stopping.Wait() 437 return err 438 } 439 440 func (device *Device) Bind() conn.Bind { 441 device.net.Lock() 442 defer device.net.Unlock() 443 return device.net.bind 444 } 445 446 func (device *Device) BindSetMark(mark uint32) error { 447 device.net.Lock() 448 defer device.net.Unlock() 449 450 // check if modified 451 if device.net.fwmark == mark { 452 return nil 453 } 454 455 // update fwmark on existing bind 456 device.net.fwmark = mark 457 if device.isUp() && device.net.bind != nil { 458 if err := device.net.bind.SetMark(mark); err != nil { 459 return err 460 } 461 } 462 463 // clear cached source addresses 464 device.peers.RLock() 465 for _, peer := range device.peers.keyMap { 466 peer.Lock() 467 defer peer.Unlock() 468 if peer.endpoint != nil { 469 peer.endpoint.ClearSrc() 470 } 471 } 472 device.peers.RUnlock() 473 474 return nil 475 } 476 477 func (device *Device) BindUpdate() error { 478 device.net.Lock() 479 defer device.net.Unlock() 480 481 // close existing sockets 482 if err := closeBindLocked(device); err != nil { 483 return err 484 } 485 486 // open new sockets 487 if !device.isUp() { 488 return nil 489 } 490 491 // bind to new port 492 var err error 493 var recvFns []conn.ReceiveFunc 494 netc := &device.net 495 496 recvFns, netc.port, err = netc.bind.Open(netc.port) 497 if err != nil { 498 netc.port = 0 499 return err 500 } 501 502 netc.netlinkCancel, err = device.startRouteListener(netc.bind) 503 if err != nil { 504 netc.bind.Close() 505 netc.port = 0 506 return err 507 } 508 509 // set fwmark 510 if netc.fwmark != 0 { 511 err = netc.bind.SetMark(netc.fwmark) 512 if err != nil { 513 return err 514 } 515 } 516 517 // clear cached source addresses 518 device.peers.RLock() 519 for _, peer := range device.peers.keyMap { 520 peer.Lock() 521 defer peer.Unlock() 522 if peer.endpoint != nil { 523 peer.endpoint.ClearSrc() 524 } 525 } 526 device.peers.RUnlock() 527 528 // start receiving routines 529 device.net.stopping.Add(len(recvFns)) 530 device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption 531 device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake 532 batchSize := netc.bind.BatchSize() 533 for _, fn := range recvFns { 534 go device.RoutineReceiveIncoming(batchSize, fn) 535 } 536 537 device.log.Verbosef("UDP bind has been updated") 538 return nil 539 } 540 541 func (device *Device) BindClose() error { 542 device.net.Lock() 543 err := closeBindLocked(device) 544 device.net.Unlock() 545 return err 546 }