github.com/cawidtu/notwireguard-go/device@v0.0.0-20230523131112-68e8e5ce9cdf/device.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"runtime"
    10  	"sync"
    11  	"sync/atomic"
    12  	"time"
    13  
    14  	"github.com/cawidtu/notwireguard-go/conn"
    15  	"github.com/cawidtu/notwireguard-go/ratelimiter"
    16  	"github.com/cawidtu/notwireguard-go/rwcancel"
    17  	"github.com/cawidtu/notwireguard-go/tun"
    18  )
    19  
// Device holds the complete state of one tunnel interface: its lifecycle
// state, UDP bind, static identity, peer table, rate limiting, buffer pools,
// worker queues and the TUN device handle.
type Device struct {
	state struct {
		// state holds the device's state. It is accessed atomically.
		// Use the device.deviceState method to read it.
		// device.deviceState does not acquire the mutex, so it captures only a snapshot.
		// During state transitions, the state variable is updated before the device itself.
		// The state is thus either the current state of the device or
		// the intended future state of the device.
		// For example, while executing a call to Up, state will be deviceStateUp.
		// There is no guarantee that the intended future state of the device
		// will become the actual state; Up can fail.
		// The device can also change state multiple times between time of check and time of use.
		// Unsynchronized uses of state must therefore be advisory/best-effort only.
		state uint32 // actually a deviceState, but typed uint32 for convenience
		// stopping blocks until all inputs to Device have been closed.
		stopping sync.WaitGroup
		// The embedded mutex (called "mu" in other comments) protects state changes.
		sync.Mutex
	}

	net struct {
		// stopping counts the receive routines started by BindUpdate;
		// closeBindLocked waits on it after closing the bind.
		stopping sync.WaitGroup
		// The embedded RWMutex is held by Bind, BindSetMark, BindUpdate and BindClose
		// while they touch the fields below.
		sync.RWMutex
		bind          conn.Bind // bind interface
		netlinkCancel *rwcancel.RWCancel // set by BindUpdate from startRouteListener; cancelled in closeBindLocked
		port          uint16 // listening port
		fwmark        uint32 // mark value (0 = disabled)
		brokenRoaming bool
	}

	staticIdentity struct {
		sync.RWMutex
		privateKey NoisePrivateKey
		publicKey  NoisePublicKey
		// obfuscator is derived from publicKey by wgNoiseCreateObfuscator
		// whenever SetPrivateKey installs a new key.
		obfuscator [NoisePublicKeySize]byte
	}

	peers struct {
		sync.RWMutex // protects keyMap
		keyMap       map[NoisePublicKey]*Peer
	}

	// Keep this 8-byte aligned: underLoadUntil is accessed with 64-bit atomics.
	rate struct {
		// underLoadUntil is a UnixNano timestamp; IsUnderLoad reports true
		// while the current time is before it.
		underLoadUntil int64
		limiter        ratelimiter.Ratelimiter
	}

	allowedips    AllowedIPs
	indexTable    IndexTable
	cookieChecker CookieChecker

	// pool is filled in by PopulatePools during NewDevice.
	pool struct {
		messageBuffers   *WaitPool
		inboundElements  *WaitPool
		outboundElements *WaitPool
	}

	// queue holds the shared work queues created in NewDevice.
	queue struct {
		encryption *outboundQueue
		decryption *inboundQueue
		handshake  *handshakeQueue
	}

	tun struct {
		device tun.Device
		mtu    int32 // taken from device.MTU() in NewDevice, or DefaultMTU if that fails
	}

	// ipcMutex is taken exclusively by upLocked so that bringing the device up
	// waits for any concurrent IPC set operation to finish.
	ipcMutex sync.RWMutex
	// closed is closed by Close once shutdown has finished; Wait returns it.
	closed chan struct{}
	log    *Logger
}
    94  
// deviceState represents the state of a Device.
// There are three states: down, up, closed.
// Transitions:
//
//   down -----+
//     ↑↓      ↓
//     up -> closed
//
// Down and up may alternate; closed can be entered from either and is never left.
type deviceState uint32

//go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState
const (
	// deviceStateDown is the zero value and the state NewDevice starts in.
	deviceStateDown deviceState = iota
	// deviceStateUp is set while the device is up or is attempting to come up.
	deviceStateUp
	// deviceStateClosed is terminal: once set, changeState ignores further requests.
	deviceStateClosed
)
   111  
   112  // deviceState returns device.state.state as a deviceState
   113  // See those docs for how to interpret this value.
   114  func (device *Device) deviceState() deviceState {
   115  	return deviceState(atomic.LoadUint32(&device.state.state))
   116  }
   117  
   118  // isClosed reports whether the device is closed (or is closing).
   119  // See device.state.state comments for how to interpret this value.
   120  func (device *Device) isClosed() bool {
   121  	return device.deviceState() == deviceStateClosed
   122  }
   123  
   124  // isUp reports whether the device is up (or is attempting to come up).
   125  // See device.state.state comments for how to interpret this value.
   126  func (device *Device) isUp() bool {
   127  	return device.deviceState() == deviceStateUp
   128  }
   129  
   130  // Must hold device.peers.Lock()
   131  func removePeerLocked(device *Device, peer *Peer, key NoisePublicKey) {
   132  	// stop routing and processing of packets
   133  	device.allowedips.RemoveByPeer(peer)
   134  	peer.Stop()
   135  
   136  	// remove from peer map
   137  	delete(device.peers.keyMap, key)
   138  }
   139  
   140  // changeState attempts to change the device state to match want.
   141  func (device *Device) changeState(want deviceState) (err error) {
   142  	device.state.Lock()
   143  	defer device.state.Unlock()
   144  	old := device.deviceState()
   145  	if old == deviceStateClosed {
   146  		// once closed, always closed
   147  		device.log.Verbosef("Interface closed, ignored requested state %s", want)
   148  		return nil
   149  	}
   150  	switch want {
   151  	case old:
   152  		return nil
   153  	case deviceStateUp:
   154  		atomic.StoreUint32(&device.state.state, uint32(deviceStateUp))
   155  		err = device.upLocked()
   156  		if err == nil {
   157  			break
   158  		}
   159  		fallthrough // up failed; bring the device all the way back down
   160  	case deviceStateDown:
   161  		atomic.StoreUint32(&device.state.state, uint32(deviceStateDown))
   162  		errDown := device.downLocked()
   163  		if err == nil {
   164  			err = errDown
   165  		}
   166  	}
   167  	device.log.Verbosef("Interface state was %s, requested %s, now %s", old, want, device.deviceState())
   168  	return
   169  }
   170  
   171  // upLocked attempts to bring the device up and reports whether it succeeded.
   172  // The caller must hold device.state.mu and is responsible for updating device.state.state.
   173  func (device *Device) upLocked() error {
   174  	if err := device.BindUpdate(); err != nil {
   175  		device.log.Errorf("Unable to update bind: %v", err)
   176  		return err
   177  	}
   178  
   179  	// The IPC set operation waits for peers to be created before calling Start() on them,
   180  	// so if there's a concurrent IPC set request happening, we should wait for it to complete.
   181  	device.ipcMutex.Lock()
   182  	defer device.ipcMutex.Unlock()
   183  
   184  	device.peers.RLock()
   185  	for _, peer := range device.peers.keyMap {
   186  		peer.Start()
   187  		if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 {
   188  			peer.SendKeepalive()
   189  		}
   190  	}
   191  	device.peers.RUnlock()
   192  	return nil
   193  }
   194  
   195  // downLocked attempts to bring the device down.
   196  // The caller must hold device.state.mu and is responsible for updating device.state.state.
   197  func (device *Device) downLocked() error {
   198  	err := device.BindClose()
   199  	if err != nil {
   200  		device.log.Errorf("Bind close failed: %v", err)
   201  	}
   202  
   203  	device.peers.RLock()
   204  	for _, peer := range device.peers.keyMap {
   205  		peer.Stop()
   206  	}
   207  	device.peers.RUnlock()
   208  	return err
   209  }
   210  
   211  func (device *Device) Up() error {
   212  	return device.changeState(deviceStateUp)
   213  }
   214  
   215  func (device *Device) Down() error {
   216  	return device.changeState(deviceStateDown)
   217  }
   218  
   219  func (device *Device) IsUnderLoad() bool {
   220  	// check if currently under load
   221  	now := time.Now()
   222  	underLoad := len(device.queue.handshake.c) >= QueueHandshakeSize/8
   223  	if underLoad {
   224  		atomic.StoreInt64(&device.rate.underLoadUntil, now.Add(UnderLoadAfterTime).UnixNano())
   225  		return true
   226  	}
   227  	// check if recently under load
   228  	return atomic.LoadInt64(&device.rate.underLoadUntil) > now.UnixNano()
   229  }
   230  
// SetPrivateKey installs sk as the device's static private key.
//
// If sk equals the current key nothing happens. Otherwise any peer whose
// public key equals the public key derived from sk is removed, the device's
// public key, obfuscator and cookie checker are re-derived, the static-static
// shared secret is recomputed for every remaining peer, and every remaining
// peer's current keypairs are expired. It always returns nil.
func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
	// lock required resources

	device.staticIdentity.Lock()
	defer device.staticIdentity.Unlock()

	if sk.Equals(device.staticIdentity.privateKey) {
		return nil
	}

	device.peers.Lock()
	defer device.peers.Unlock()

	// Take a read lock on every peer's handshake state and remember which
	// peers were locked, so they can all be released below even if a peer
	// is removed from the map in between.
	lockedPeers := make([]*Peer, 0, len(device.peers.keyMap))
	for _, peer := range device.peers.keyMap {
		peer.handshake.mutex.RLock()
		lockedPeers = append(lockedPeers, peer)
	}

	// remove peers with matching public keys

	publicKey := sk.publicKey()
	for key, peer := range device.peers.keyMap {
		if peer.handshake.remoteStatic.Equals(publicKey) {
			// The read lock is dropped around the removal and re-taken afterwards
			// so the final unlock loop stays balanced.
			// NOTE(review): presumably peer.Stop needs the handshake lock — confirm.
			peer.handshake.mutex.RUnlock()
			removePeerLocked(device, peer, key)
			peer.handshake.mutex.RLock()
		}
	}

	// update key material

	device.staticIdentity.privateKey = sk
	device.staticIdentity.publicKey = publicKey
	device.staticIdentity.obfuscator = wgNoiseCreateObfuscator(device.staticIdentity.publicKey)
	device.cookieChecker.Init(publicKey)

	// do static-static DH pre-computations

	// NOTE(review): precomputedStaticStatic is written while only the read side
	// of handshake.mutex is held — verify that no concurrent reader can observe it.
	expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
	for _, peer := range device.peers.keyMap {
		handshake := &peer.handshake
		handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
		expiredPeers = append(expiredPeers, peer)
	}

	// Release the handshake read locks before expiring keypairs.
	for _, peer := range lockedPeers {
		peer.handshake.mutex.RUnlock()
	}
	for _, peer := range expiredPeers {
		peer.ExpireCurrentKeypairs()
	}

	return nil
}
   286  
   287  func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
   288  	device := new(Device)
   289  	device.state.state = uint32(deviceStateDown)
   290  	device.closed = make(chan struct{})
   291  	device.log = logger
   292  	device.net.bind = bind
   293  	device.tun.device = tunDevice
   294  	mtu, err := device.tun.device.MTU()
   295  	if err != nil {
   296  		device.log.Errorf("Trouble determining MTU, assuming default: %v", err)
   297  		mtu = DefaultMTU
   298  	}
   299  	device.tun.mtu = int32(mtu)
   300  	device.peers.keyMap = make(map[NoisePublicKey]*Peer)
   301  	device.rate.limiter.Init()
   302  	device.indexTable.Init()
   303  	device.PopulatePools()
   304  
   305  	// create queues
   306  
   307  	device.queue.handshake = newHandshakeQueue()
   308  	device.queue.encryption = newOutboundQueue()
   309  	device.queue.decryption = newInboundQueue()
   310  
   311  	// start workers
   312  
   313  	cpus := runtime.NumCPU()
   314  	device.state.stopping.Wait()
   315  	device.queue.encryption.wg.Add(cpus) // One for each RoutineHandshake
   316  	for i := 0; i < cpus; i++ {
   317  		go device.RoutineEncryption(i + 1)
   318  		go device.RoutineDecryption(i + 1)
   319  		go device.RoutineHandshake(i + 1)
   320  	}
   321  
   322  	device.state.stopping.Add(1)      // RoutineReadFromTUN
   323  	device.queue.encryption.wg.Add(1) // RoutineReadFromTUN
   324  	go device.RoutineReadFromTUN()
   325  	go device.RoutineTUNEventReader()
   326  
   327  	return device
   328  }
   329  
   330  func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
   331  	device.peers.RLock()
   332  	defer device.peers.RUnlock()
   333  
   334  	return device.peers.keyMap[pk]
   335  }
   336  
   337  func (device *Device) RemovePeer(key NoisePublicKey) {
   338  	device.peers.Lock()
   339  	defer device.peers.Unlock()
   340  	// stop peer and remove from routing
   341  
   342  	peer, ok := device.peers.keyMap[key]
   343  	if ok {
   344  		removePeerLocked(device, peer, key)
   345  	}
   346  }
   347  
   348  func (device *Device) RemoveAllPeers() {
   349  	device.peers.Lock()
   350  	defer device.peers.Unlock()
   351  
   352  	for key, peer := range device.peers.keyMap {
   353  		removePeerLocked(device, peer, key)
   354  	}
   355  
   356  	device.peers.keyMap = make(map[NoisePublicKey]*Peer)
   357  }
   358  
// Close shuts the device down permanently.
//
// It holds the device state mutex for the whole shutdown. If the device is
// already closed it returns immediately. Otherwise it marks the device closed,
// closes the TUN device, brings the device down, removes all peers, releases
// the device's own references on the three work queues, waits for the device's
// inputs to stop, closes the rate limiter, and finally closes the channel
// returned by Wait.
func (device *Device) Close() {
	device.state.Lock()
	defer device.state.Unlock()
	if device.isClosed() {
		return
	}
	// Publish the closed state first so that changeState rejects further requests.
	atomic.StoreUint32(&device.state.state, uint32(deviceStateClosed))
	device.log.Verbosef("Device closing")

	// NOTE(review): the results of tun.device.Close and downLocked are discarded;
	// downLocked logs a bind close failure itself.
	device.tun.device.Close()
	device.downLocked()

	// Remove peers before closing queues,
	// because peers assume that queues are active.
	device.RemoveAllPeers()

	// We kept a reference to the encryption and decryption queues,
	// in case we started any new peers that might write to them.
	// No new peers are coming; we are done with these queues.
	device.queue.encryption.wg.Done()
	device.queue.decryption.wg.Done()
	device.queue.handshake.wg.Done()
	// Wait for everything registered on state.stopping (RoutineReadFromTUN
	// is registered in NewDevice) to finish.
	device.state.stopping.Wait()

	device.rate.limiter.Close()

	device.log.Verbosef("Device closed")
	// Wake everyone blocked on Wait().
	close(device.closed)
}
   388  
   389  func (device *Device) Wait() chan struct{} {
   390  	return device.closed
   391  }
   392  
   393  func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
   394  	if !device.isUp() {
   395  		return
   396  	}
   397  
   398  	device.peers.RLock()
   399  	for _, peer := range device.peers.keyMap {
   400  		peer.keypairs.RLock()
   401  		sendKeepalive := peer.keypairs.current != nil && !peer.keypairs.current.created.Add(RejectAfterTime).Before(time.Now())
   402  		peer.keypairs.RUnlock()
   403  		if sendKeepalive {
   404  			peer.SendKeepalive()
   405  		}
   406  	}
   407  	device.peers.RUnlock()
   408  }
   409  
   410  // closeBindLocked closes the device's net.bind.
   411  // The caller must hold the net mutex.
   412  func closeBindLocked(device *Device) error {
   413  	var err error
   414  	netc := &device.net
   415  	if netc.netlinkCancel != nil {
   416  		netc.netlinkCancel.Cancel()
   417  	}
   418  	if netc.bind != nil {
   419  		err = netc.bind.Close()
   420  	}
   421  	netc.stopping.Wait()
   422  	return err
   423  }
   424  
   425  func (device *Device) Bind() conn.Bind {
   426  	device.net.Lock()
   427  	defer device.net.Unlock()
   428  	return device.net.bind
   429  }
   430  
   431  func (device *Device) BindSetMark(mark uint32) error {
   432  	device.net.Lock()
   433  	defer device.net.Unlock()
   434  
   435  	// check if modified
   436  	if device.net.fwmark == mark {
   437  		return nil
   438  	}
   439  
   440  	// update fwmark on existing bind
   441  	device.net.fwmark = mark
   442  	if device.isUp() && device.net.bind != nil {
   443  		if err := device.net.bind.SetMark(mark); err != nil {
   444  			return err
   445  		}
   446  	}
   447  
   448  	// clear cached source addresses
   449  	device.peers.RLock()
   450  	for _, peer := range device.peers.keyMap {
   451  		peer.Lock()
   452  		defer peer.Unlock()
   453  		if peer.endpoint != nil {
   454  			peer.endpoint.ClearSrc()
   455  		}
   456  	}
   457  	device.peers.RUnlock()
   458  
   459  	return nil
   460  }
   461  
   462  func (device *Device) BindUpdate() error {
   463  	device.net.Lock()
   464  	defer device.net.Unlock()
   465  
   466  	// close existing sockets
   467  	if err := closeBindLocked(device); err != nil {
   468  		return err
   469  	}
   470  
   471  	// open new sockets
   472  	if !device.isUp() {
   473  		return nil
   474  	}
   475  
   476  	// bind to new port
   477  	var err error
   478  	var recvFns []conn.ReceiveFunc
   479  	netc := &device.net
   480  	recvFns, netc.port, err = netc.bind.Open(netc.port)
   481  	if err != nil {
   482  		netc.port = 0
   483  		return err
   484  	}
   485  	netc.netlinkCancel, err = device.startRouteListener(netc.bind)
   486  	if err != nil {
   487  		netc.bind.Close()
   488  		netc.port = 0
   489  		return err
   490  	}
   491  
   492  	// set fwmark
   493  	if netc.fwmark != 0 {
   494  		err = netc.bind.SetMark(netc.fwmark)
   495  		if err != nil {
   496  			return err
   497  		}
   498  	}
   499  
   500  	// clear cached source addresses
   501  	device.peers.RLock()
   502  	for _, peer := range device.peers.keyMap {
   503  		peer.Lock()
   504  		defer peer.Unlock()
   505  		if peer.endpoint != nil {
   506  			peer.endpoint.ClearSrc()
   507  		}
   508  	}
   509  	device.peers.RUnlock()
   510  
   511  	// start receiving routines
   512  	device.net.stopping.Add(len(recvFns))
   513  	device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
   514  	device.queue.handshake.wg.Add(len(recvFns))  // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
   515  	for _, fn := range recvFns {
   516  		go device.RoutineReceiveIncoming(fn)
   517  	}
   518  
   519  	device.log.Verbosef("UDP bind has been updated")
   520  	return nil
   521  }
   522  
   523  func (device *Device) BindClose() error {
   524  	device.net.Lock()
   525  	err := closeBindLocked(device)
   526  	device.net.Unlock()
   527  	return err
   528  }