github.com/bugfan/wireguard-go@v0.0.0-20230720020150-a7b2fa340c66/device/device.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package device 7 8 import ( 9 "context" 10 "encoding/base64" 11 "encoding/hex" 12 "fmt" 13 "runtime" 14 "sync" 15 "sync/atomic" 16 "time" 17 18 "github.com/bugfan/wireguard-go/auth" 19 "github.com/bugfan/wireguard-go/conn" 20 "github.com/bugfan/wireguard-go/ratelimiter" 21 "github.com/bugfan/wireguard-go/rwcancel" 22 "github.com/bugfan/wireguard-go/tun" 23 ) 24 25 type Device struct { 26 state struct { 27 // state holds the device's state. It is accessed atomically. 28 // Use the device.deviceState method to read it. 29 // device.deviceState does not acquire the mutex, so it captures only a snapshot. 30 // During state transitions, the state variable is updated before the device itself. 31 // The state is thus either the current state of the device or 32 // the intended future state of the device. 33 // For example, while executing a call to Up, state will be deviceStateUp. 34 // There is no guarantee that that intended future state of the device 35 // will become the actual state; Up can fail. 36 // The device can also change state multiple times between time of check and time of use. 37 // Unsynchronized uses of state must therefore be advisory/best-effort only. 38 state uint32 // actually a deviceState, but typed uint32 for convenience 39 // stopping blocks until all inputs to Device have been closed. 40 stopping sync.WaitGroup 41 // mu protects state changes. 42 sync.Mutex 43 } 44 45 net struct { 46 stopping sync.WaitGroup 47 sync.RWMutex 48 bind conn.Bind // bind interface 49 netlinkCancel *rwcancel.RWCancel 50 port uint16 // listening port 51 fwmark uint32 // mark value (0 = disabled) 52 } 53 54 staticIdentity struct { 55 sync.RWMutex 56 privateKey NoisePrivateKey 57 publicKey NoisePublicKey 58 } 59 60 rate struct { 61 underLoadUntil int64 62 limiter ratelimiter.Ratelimiter 63 } 64 65 peers struct { 66 sync.RWMutex // protects keyMap 67 keyMap map[NoisePublicKey]*Peer 68 } 69 70 allowedips AllowedIPs 71 indexTable IndexTable 72 cookieChecker CookieChecker 73 74 pool struct { 75 messageBuffers *WaitPool 76 inboundElements *WaitPool 77 outboundElements *WaitPool 78 } 79 80 queue struct { 81 encryption *outboundQueue 82 decryption *inboundQueue 83 handshake *handshakeQueue 84 } 85 86 tun struct { 87 device tun.Device 88 mtu int32 89 } 90 91 ipcMutex sync.RWMutex 92 closed chan struct{} 93 log *Logger 94 } 95 96 // deviceState represents the state of a Device. 97 // There are three states: down, up, closed. 98 // Transitions: 99 // 100 // down -----+ 101 // ↑↓ ↓ 102 // up -> closed 103 // 104 type deviceState uint32 105 106 //go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState 107 const ( 108 deviceStateDown deviceState = iota 109 deviceStateUp 110 deviceStateClosed 111 ) 112 113 // deviceState returns device.state.state as a deviceState 114 // See those docs for how to interpret this value. 115 func (device *Device) deviceState() deviceState { 116 return deviceState(atomic.LoadUint32(&device.state.state)) 117 } 118 119 // isClosed reports whether the device is closed (or is closing). 120 // See device.state.state comments for how to interpret this value. 121 func (device *Device) isClosed() bool { 122 return device.deviceState() == deviceStateClosed 123 } 124 125 // isUp reports whether the device is up (or is attempting to come up). 126 // See device.state.state comments for how to interpret this value. 127 func (device *Device) isUp() bool { 128 return device.deviceState() == deviceStateUp 129 } 130 131 // Must hold device.peers.Lock() 132 func removePeerLocked(device *Device, peer *Peer, key NoisePublicKey) { 133 // stop routing and processing of packets 134 device.allowedips.RemoveByPeer(peer) 135 peer.Stop() 136 137 // remove from peer map 138 delete(device.peers.keyMap, key) 139 } 140 141 // changeState attempts to change the device state to match want. 142 func (device *Device) changeState(want deviceState) (err error) { 143 device.state.Lock() 144 defer device.state.Unlock() 145 old := device.deviceState() 146 if old == deviceStateClosed { 147 // once closed, always closed 148 device.log.Verbosef("Interface closed, ignored requested state %s", want) 149 return nil 150 } 151 switch want { 152 case old: 153 return nil 154 case deviceStateUp: 155 atomic.StoreUint32(&device.state.state, uint32(deviceStateUp)) 156 err = device.upLocked() 157 if err == nil { 158 break 159 } 160 fallthrough // up failed; bring the device all the way back down 161 case deviceStateDown: 162 atomic.StoreUint32(&device.state.state, uint32(deviceStateDown)) 163 errDown := device.downLocked() 164 if err == nil { 165 err = errDown 166 } 167 } 168 device.log.Verbosef("Interface state was %s, requested %s, now %s", old, want, device.deviceState()) 169 return 170 } 171 172 // upLocked attempts to bring the device up and reports whether it succeeded. 173 // The caller must hold device.state.mu and is responsible for updating device.state.state. 174 func (device *Device) upLocked() error { 175 if err := device.BindUpdate(); err != nil { 176 device.log.Errorf("Unable to update bind: %v", err) 177 return err 178 } 179 180 device.peers.RLock() 181 for _, peer := range device.peers.keyMap { 182 peer.Start() 183 if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 { 184 peer.SendKeepalive() 185 } 186 } 187 device.peers.RUnlock() 188 return nil 189 } 190 191 // downLocked attempts to bring the device down. 192 // The caller must hold device.state.mu and is responsible for updating device.state.state. 193 func (device *Device) downLocked() error { 194 err := device.BindClose() 195 if err != nil { 196 device.log.Errorf("Bind close failed: %v", err) 197 } 198 199 device.peers.RLock() 200 for _, peer := range device.peers.keyMap { 201 peer.Stop() 202 } 203 device.peers.RUnlock() 204 return err 205 } 206 207 func (device *Device) Up() error { 208 return device.changeState(deviceStateUp) 209 } 210 211 func (device *Device) Down() error { 212 return device.changeState(deviceStateDown) 213 } 214 215 func (device *Device) IsUnderLoad() bool { 216 // check if currently under load 217 now := time.Now() 218 underLoad := len(device.queue.handshake.c) >= QueueHandshakeSize/8 219 if underLoad { 220 atomic.StoreInt64(&device.rate.underLoadUntil, now.Add(UnderLoadAfterTime).UnixNano()) 221 return true 222 } 223 // check if recently under load 224 return atomic.LoadInt64(&device.rate.underLoadUntil) > now.UnixNano() 225 } 226 227 func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { 228 // lock required resources 229 230 device.staticIdentity.Lock() 231 defer device.staticIdentity.Unlock() 232 233 if sk.Equals(device.staticIdentity.privateKey) { 234 return nil 235 } 236 237 device.peers.Lock() 238 defer device.peers.Unlock() 239 240 lockedPeers := make([]*Peer, 0, len(device.peers.keyMap)) 241 for _, peer := range device.peers.keyMap { 242 peer.handshake.mutex.RLock() 243 lockedPeers = append(lockedPeers, peer) 244 } 245 246 // remove peers with matching public keys 247 248 publicKey := sk.publicKey() 249 for key, peer := range device.peers.keyMap { 250 if peer.handshake.remoteStatic.Equals(publicKey) { 251 peer.handshake.mutex.RUnlock() 252 removePeerLocked(device, peer, key) 253 peer.handshake.mutex.RLock() 254 } 255 } 256 257 // update key material 258 259 device.staticIdentity.privateKey = sk 260 device.staticIdentity.publicKey = publicKey 261 device.cookieChecker.Init(publicKey) 262 263 // do static-static DH pre-computations 264 265 expiredPeers := make([]*Peer, 0, len(device.peers.keyMap)) 266 for _, peer := range device.peers.keyMap { 267 handshake := &peer.handshake 268 handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic) 269 expiredPeers = append(expiredPeers, peer) 270 } 271 272 for _, peer := range lockedPeers { 273 peer.handshake.mutex.RUnlock() 274 } 275 for _, peer := range expiredPeers { 276 peer.ExpireCurrentKeypairs() 277 } 278 279 return nil 280 } 281 282 func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device { 283 device := new(Device) 284 device.state.state = uint32(deviceStateDown) 285 device.closed = make(chan struct{}) 286 device.log = logger 287 device.net.bind = bind 288 device.tun.device = tunDevice 289 mtu, err := device.tun.device.MTU() 290 if err != nil { 291 device.log.Errorf("Trouble determining MTU, assuming default: %v", err) 292 mtu = DefaultMTU 293 } 294 device.tun.mtu = int32(mtu) 295 device.peers.keyMap = make(map[NoisePublicKey]*Peer) 296 device.rate.limiter.Init() 297 device.indexTable.Init() 298 device.PopulatePools() 299 300 // create queues 301 302 device.queue.handshake = newHandshakeQueue() 303 device.queue.encryption = newOutboundQueue() 304 device.queue.decryption = newInboundQueue() 305 306 // start workers 307 308 cpus := runtime.NumCPU() 309 device.state.stopping.Wait() 310 device.queue.encryption.wg.Add(cpus) // One for each RoutineHandshake 311 for i := 0; i < cpus; i++ { 312 go device.RoutineEncryption(i + 1) 313 go device.RoutineDecryption(i + 1) 314 go device.RoutineHandshake(i + 1) 315 } 316 317 device.state.stopping.Add(1) // RoutineReadFromTUN 318 device.queue.encryption.wg.Add(1) // RoutineReadFromTUN 319 go device.RoutineReadFromTUN() 320 go device.RoutineTUNEventReader() 321 322 return device 323 } 324 325 func btoa(data []byte) []byte { 326 // Base64 Standard Encoding 327 sEnc := base64.StdEncoding.EncodeToString(data) 328 // fmt.Println(sEnc) // aGVsbG8gd29ybGQxMjM0NSE/JComKCknLUB+ 329 return []byte(sEnc) 330 } 331 332 func ParseNPK(pk NoisePublicKey) (string, []byte) { 333 buf := []byte{} 334 for i := 0; i < len(pk); i++ { 335 buf = append(buf, pk[i]) 336 } 337 return string(btoa(buf)), buf 338 } 339 340 func ToWgIpcSetString(bs []byte, key, ips string) string { 341 342 hexKey := hex.EncodeToString(bs) 343 344 return fmt.Sprintf(`public_key=%s 345 replace_allowed_ips=true 346 allowed_ip=%s`, hexKey, ips) 347 } 348 func (device *Device) MyWgSet(conf string) { 349 ctx, cancelFunc := context.WithTimeout(context.Background(), time.Second/2) //500ms timeout 350 go func() { 351 defer cancelFunc() 352 device.log.Verbosef("use context set peer config begin") 353 err := device.IpcSet(conf) 354 if err != nil { 355 device.log.Verbosef("device set config faild,error is %v,config is '%s'\n", err, conf) 356 } 357 device.log.Verbosef("use context set peer config end") 358 359 }() 360 <-ctx.Done() // wait goroutine 361 } 362 363 func (device *Device) LookupPeer(pk NoisePublicKey, master bool) *Peer { 364 365 // if already has,return it 366 device.peers.RLock() 367 o := device.peers.keyMap[pk] 368 device.peers.RUnlock() 369 if o != nil { 370 return o 371 } 372 373 // if dont't have ,go to auth server 374 str, bs := ParseNPK(pk) 375 fmt.Printf("[authentication] peer %s verify auth\n", str) 376 peer, err := auth.Verify(str) 377 if err != nil { 378 fmt.Printf("[authentication-faild] can not verify %s 's authority,error is %v\n", str, err) 379 return nil 380 } 381 fmt.Printf("[authentication-ok] %s verify ok\n", str) 382 383 // authorication and set peer to keyMap 384 conf := ToWgIpcSetString(bs, peer.PublicKey, peer.Address) 385 device.MyWgSet(conf) 386 387 // lock op move to final 388 device.peers.RLock() 389 defer device.peers.RUnlock() 390 391 return device.peers.keyMap[pk] 392 } 393 394 func (device *Device) RemovePeer(key NoisePublicKey) { 395 device.peers.Lock() 396 defer device.peers.Unlock() 397 // stop peer and remove from routing 398 399 peer, ok := device.peers.keyMap[key] 400 if ok { 401 removePeerLocked(device, peer, key) 402 } 403 } 404 405 func (device *Device) RemoveAllPeers() { 406 device.peers.Lock() 407 defer device.peers.Unlock() 408 409 for key, peer := range device.peers.keyMap { 410 removePeerLocked(device, peer, key) 411 } 412 413 device.peers.keyMap = make(map[NoisePublicKey]*Peer) 414 } 415 416 func (device *Device) Close() { 417 device.state.Lock() 418 defer device.state.Unlock() 419 if device.isClosed() { 420 return 421 } 422 atomic.StoreUint32(&device.state.state, uint32(deviceStateClosed)) 423 device.log.Verbosef("Device closing") 424 425 device.tun.device.Close() 426 device.downLocked() 427 428 // Remove peers before closing queues, 429 // because peers assume that queues are active. 430 device.RemoveAllPeers() 431 432 // We kept a reference to the encryption and decryption queues, 433 // in case we started any new peers that might write to them. 434 // No new peers are coming; we are done with these queues. 435 device.queue.encryption.wg.Done() 436 device.queue.decryption.wg.Done() 437 device.queue.handshake.wg.Done() 438 device.state.stopping.Wait() 439 440 device.rate.limiter.Close() 441 442 device.log.Verbosef("Device closed") 443 close(device.closed) 444 } 445 446 func (device *Device) Wait() chan struct{} { 447 return device.closed 448 } 449 450 func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() { 451 if !device.isUp() { 452 return 453 } 454 455 device.peers.RLock() 456 for _, peer := range device.peers.keyMap { 457 peer.keypairs.RLock() 458 sendKeepalive := peer.keypairs.current != nil && !peer.keypairs.current.created.Add(RejectAfterTime).Before(time.Now()) 459 peer.keypairs.RUnlock() 460 if sendKeepalive { 461 peer.SendKeepalive() 462 } 463 } 464 device.peers.RUnlock() 465 } 466 467 // closeBindLocked closes the device's net.bind. 468 // The caller must hold the net mutex. 469 func closeBindLocked(device *Device) error { 470 var err error 471 netc := &device.net 472 if netc.netlinkCancel != nil { 473 netc.netlinkCancel.Cancel() 474 } 475 if netc.bind != nil { 476 err = netc.bind.Close() 477 } 478 netc.stopping.Wait() 479 return err 480 } 481 482 func (device *Device) Bind() conn.Bind { 483 device.net.Lock() 484 defer device.net.Unlock() 485 return device.net.bind 486 } 487 488 func (device *Device) BindSetMark(mark uint32) error { 489 device.net.Lock() 490 defer device.net.Unlock() 491 492 // check if modified 493 if device.net.fwmark == mark { 494 return nil 495 } 496 497 // update fwmark on existing bind 498 device.net.fwmark = mark 499 if device.isUp() && device.net.bind != nil { 500 if err := device.net.bind.SetMark(mark); err != nil { 501 return err 502 } 503 } 504 505 // clear cached source addresses 506 device.peers.RLock() 507 for _, peer := range device.peers.keyMap { 508 peer.Lock() 509 defer peer.Unlock() 510 if peer.endpoint != nil { 511 peer.endpoint.ClearSrc() 512 } 513 } 514 device.peers.RUnlock() 515 516 return nil 517 } 518 519 func (device *Device) BindUpdate() error { 520 device.net.Lock() 521 defer device.net.Unlock() 522 523 // close existing sockets 524 if err := closeBindLocked(device); err != nil { 525 return err 526 } 527 528 // open new sockets 529 if !device.isUp() { 530 return nil 531 } 532 533 // bind to new port 534 var err error 535 var recvFns []conn.ReceiveFunc 536 netc := &device.net 537 recvFns, netc.port, err = netc.bind.Open(netc.port) 538 if err != nil { 539 netc.port = 0 540 return err 541 } 542 netc.netlinkCancel, err = device.startRouteListener(netc.bind) 543 if err != nil { 544 netc.bind.Close() 545 netc.port = 0 546 return err 547 } 548 549 // set fwmark 550 if netc.fwmark != 0 { 551 err = netc.bind.SetMark(netc.fwmark) 552 if err != nil { 553 return err 554 } 555 } 556 557 // clear cached source addresses 558 device.peers.RLock() 559 for _, peer := range device.peers.keyMap { 560 peer.Lock() 561 defer peer.Unlock() 562 if peer.endpoint != nil { 563 peer.endpoint.ClearSrc() 564 } 565 } 566 device.peers.RUnlock() 567 568 // start receiving routines 569 device.net.stopping.Add(len(recvFns)) 570 device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption 571 device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake 572 for _, fn := range recvFns { 573 go device.RoutineReceiveIncoming(fn) 574 } 575 576 device.log.Verbosef("UDP bind has been updated") 577 return nil 578 } 579 580 func (device *Device) BindClose() error { 581 device.net.Lock() 582 err := closeBindLocked(device) 583 device.net.Unlock() 584 return err 585 }