github.com/amnezia-vpn/amneziawg-go@v0.2.8/device/peer.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"container/list"
    10  	"errors"
    11  	"sync"
    12  	"sync/atomic"
    13  	"time"
    14  
    15  	"github.com/amnezia-vpn/amneziawg-go/conn"
    16  )
    17  
// Peer holds the per-remote-peer state registered with a Device: session
// keypairs, handshake state, the current remote endpoint, timers, traffic
// counters, and the queues that feed the peer's sequential sender and
// receiver goroutines.
type Peer struct {
	// isRunning is set to true by Start and swapped back to false by Stop.
	isRunning         atomic.Bool
	keypairs          Keypairs
	handshake         Handshake
	device            *Device
	stopping          sync.WaitGroup // routines pending stop
	txBytes           atomic.Uint64  // bytes sent to peer (endpoint)
	rxBytes           atomic.Uint64  // bytes received from peer
	lastHandshakeNano atomic.Int64   // nano seconds since epoch

	// endpoint is the peer's current remote address, guarded by the embedded mutex.
	endpoint struct {
		sync.Mutex
		val            conn.Endpoint
		clearSrcOnTx   bool // signal to val.ClearSrc() prior to next packet transmission
		disableRoaming bool // when true, SetEndpointFromPacket leaves val untouched
	}

	// timers are set up by timersInit (called from NewPeer), started by Start
	// and stopped by Stop.
	timers struct {
		retransmitHandshake     *Timer
		sendKeepalive           *Timer
		newHandshake            *Timer
		zeroKeyMaterial         *Timer
		persistentKeepalive     *Timer
		handshakeAttempts       atomic.Uint32
		needAnotherKeepalive    atomic.Bool
		sentLastMinuteHandshake atomic.Bool
	}

	state struct {
		sync.Mutex // protects against concurrent Start/Stop
	}

	queue struct {
		staged   chan *QueueOutboundElementsContainer // staged packets before a handshake is available
		outbound *autodrainingOutboundQueue           // sequential ordering of udp transmission
		inbound  *autodrainingInboundQueue            // sequential ordering of tun writing
	}

	// cookieGenerator is initialised with the peer's public key in NewPeer.
	cookieGenerator CookieGenerator
	// NOTE(review): presumably links this peer's entries in the device's
	// allowed-IPs structure — confirm against the trie code.
	trieEntries list.List
	// NOTE(review): unit not visible here (presumably seconds) — confirm.
	persistentKeepaliveInterval atomic.Uint32
}
    60  
    61  func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
    62  	if device.isClosed() {
    63  		return nil, errors.New("device closed")
    64  	}
    65  
    66  	// lock resources
    67  	device.staticIdentity.RLock()
    68  	defer device.staticIdentity.RUnlock()
    69  
    70  	device.peers.Lock()
    71  	defer device.peers.Unlock()
    72  
    73  	// check if over limit
    74  	if len(device.peers.keyMap) >= MaxPeers {
    75  		return nil, errors.New("too many peers")
    76  	}
    77  
    78  	// create peer
    79  	peer := new(Peer)
    80  
    81  	peer.cookieGenerator.Init(pk)
    82  	peer.device = device
    83  	peer.queue.outbound = newAutodrainingOutboundQueue(device)
    84  	peer.queue.inbound = newAutodrainingInboundQueue(device)
    85  	peer.queue.staged = make(chan *QueueOutboundElementsContainer, QueueStagedSize)
    86  
    87  	// map public key
    88  	_, ok := device.peers.keyMap[pk]
    89  	if ok {
    90  		return nil, errors.New("adding existing peer")
    91  	}
    92  
    93  	// pre-compute DH
    94  	handshake := &peer.handshake
    95  	handshake.mutex.Lock()
    96  	handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(pk)
    97  	handshake.remoteStatic = pk
    98  	handshake.mutex.Unlock()
    99  
   100  	// reset endpoint
   101  	peer.endpoint.Lock()
   102  	peer.endpoint.val = nil
   103  	peer.endpoint.disableRoaming = false
   104  	peer.endpoint.clearSrcOnTx = false
   105  	peer.endpoint.Unlock()
   106  
   107  	// init timers
   108  	peer.timersInit()
   109  
   110  	// add
   111  	device.peers.keyMap[pk] = peer
   112  
   113  	return peer, nil
   114  }
   115  
   116  func (peer *Peer) SendBuffers(buffers [][]byte) error {
   117  	peer.device.net.RLock()
   118  	defer peer.device.net.RUnlock()
   119  
   120  	if peer.device.isClosed() {
   121  		return nil
   122  	}
   123  
   124  	peer.endpoint.Lock()
   125  	endpoint := peer.endpoint.val
   126  	if endpoint == nil {
   127  		peer.endpoint.Unlock()
   128  		return errors.New("no known endpoint for peer")
   129  	}
   130  	if peer.endpoint.clearSrcOnTx {
   131  		endpoint.ClearSrc()
   132  		peer.endpoint.clearSrcOnTx = false
   133  	}
   134  	peer.endpoint.Unlock()
   135  
   136  	err := peer.device.net.bind.Send(buffers, endpoint)
   137  	if err == nil {
   138  		var totalLen uint64
   139  		for _, b := range buffers {
   140  			totalLen += uint64(len(b))
   141  		}
   142  		peer.txBytes.Add(totalLen)
   143  	}
   144  	return err
   145  }
   146  
// String returns a short identifier for the peer of the form
// "peer(XXXX…YYYY)". XXXX are characters 0–3 and YYYY are characters 39–42
// of the standard base64 encoding of the peer's remote static public key.
//
// NOTE(review): reads handshake.remoteStatic without holding handshake.mutex;
// presumably safe because it is only assigned in NewPeer — confirm.
func (peer *Peer) String() string {
	// The awful goo that follows is identical to:
	//
	//   base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:])
	//   abbreviatedKey := base64Key[0:4] + "…" + base64Key[39:43]
	//   return fmt.Sprintf("peer(%s)", abbreviatedKey)
	//
	// except that it is considerably more efficient.
	src := peer.handshake.remoteStatic
	// b64 maps a 6-bit value (0..63) to its standard base64 alphabet
	// character using branch-free arithmetic. Each ((K-int(input))>>8)&M term
	// contributes M exactly when input > K, which shifts the offset from 'A'
	// into the next alphabet range.
	b64 := func(input byte) byte {
		return input + 'A' + byte(((25-int(input))>>8)&6) - byte(((51-int(input))>>8)&75) - byte(((61-int(input))>>8)&15) + byte(((62-int(input))>>8)&3)
	}
	// Template whose '_' placeholders are overwritten below; "…" is a
	// multi-byte UTF-8 sequence, so the offsets are computed with len().
	b := []byte("peer(____…____)")
	const first = len("peer(")
	const second = len("peer(____…")
	// Base64 characters 0..3 come from key bytes 0..2.
	b[first+0] = b64((src[0] >> 2) & 63)
	b[first+1] = b64(((src[0] << 4) | (src[1] >> 4)) & 63)
	b[first+2] = b64(((src[1] << 2) | (src[2] >> 6)) & 63)
	b[first+3] = b64(src[2] & 63)
	// Base64 character 39 is the low 6 bits of byte 29; characters 40..42
	// come from the final bytes 30..31, zero-padded at the end.
	b[second+0] = b64(src[29] & 63)
	b[second+1] = b64((src[30] >> 2) & 63)
	b[second+2] = b64(((src[30] << 4) | (src[31] >> 4)) & 63)
	b[second+3] = b64((src[31] << 2) & 63)
	return string(b)
}
   172  
// Start starts the peer's timers and launches its RoutineSequentialSender
// and RoutineSequentialReceiver goroutines, then marks the peer running.
// It does nothing if the device is closed or the peer is already running.
// Start and Stop are serialized by peer.state.
func (peer *Peer) Start() {
	// should never start a peer on a closed device
	if peer.device.isClosed() {
		return
	}

	// prevent simultaneous start/stop operations
	peer.state.Lock()
	defer peer.state.Unlock()

	if peer.isRunning.Load() {
		return
	}

	device := peer.device
	device.log.Verbosef("%v - Starting", peer)

	// reset routine state
	// Wait out any routines from a previous run, then account for the two
	// goroutines launched below (presumably each calls stopping.Done on exit;
	// Stop waits on this group).
	peer.stopping.Wait()
	peer.stopping.Add(2)

	// Backdate lastSentHandshake by more than RekeyTimeout, presumably so a
	// new handshake is not treated as too recent — confirm with the
	// handshake-initiation code.
	peer.handshake.mutex.Lock()
	peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
	peer.handshake.mutex.Unlock()

	peer.device.queue.encryption.wg.Add(1) // keep encryption queue open for our writes

	peer.timersStart()

	// Presumably discards anything left in the peer queues from a previous run.
	device.flushInboundQueue(peer.queue.inbound)
	device.flushOutboundQueue(peer.queue.outbound)

	// Use the device batch size, not the bind batch size, as the device size is
	// the size of the batch pools.
	batchSize := peer.device.BatchSize()
	go peer.RoutineSequentialSender(batchSize)
	go peer.RoutineSequentialReceiver(batchSize)

	peer.isRunning.Store(true)
}
   213  
   214  func (peer *Peer) ZeroAndFlushAll() {
   215  	device := peer.device
   216  
   217  	// clear key pairs
   218  
   219  	keypairs := &peer.keypairs
   220  	keypairs.Lock()
   221  	device.DeleteKeypair(keypairs.previous)
   222  	device.DeleteKeypair(keypairs.current)
   223  	device.DeleteKeypair(keypairs.next.Load())
   224  	keypairs.previous = nil
   225  	keypairs.current = nil
   226  	keypairs.next.Store(nil)
   227  	keypairs.Unlock()
   228  
   229  	// clear handshake state
   230  
   231  	handshake := &peer.handshake
   232  	handshake.mutex.Lock()
   233  	device.indexTable.Delete(handshake.localIndex)
   234  	handshake.Clear()
   235  	handshake.mutex.Unlock()
   236  
   237  	peer.FlushStagedPackets()
   238  }
   239  
   240  func (peer *Peer) ExpireCurrentKeypairs() {
   241  	handshake := &peer.handshake
   242  	handshake.mutex.Lock()
   243  	peer.device.indexTable.Delete(handshake.localIndex)
   244  	handshake.Clear()
   245  	peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
   246  	handshake.mutex.Unlock()
   247  
   248  	keypairs := &peer.keypairs
   249  	keypairs.Lock()
   250  	if keypairs.current != nil {
   251  		keypairs.current.sendNonce.Store(RejectAfterMessages)
   252  	}
   253  	if next := keypairs.next.Load(); next != nil {
   254  		next.sendNonce.Store(RejectAfterMessages)
   255  	}
   256  	keypairs.Unlock()
   257  }
   258  
// Stop halts a running peer. It stops the timers, signals the sequential
// sender and receiver routines to exit and waits for them. It then releases
// the peer's hold on the device encryption queue. Finally it calls
// ZeroAndFlushAll, which clears keypairs and handshake state and flushes
// staged packets. It does nothing if the peer is not running. Start and Stop
// are serialized by peer.state.
func (peer *Peer) Stop() {
	peer.state.Lock()
	defer peer.state.Unlock()

	// Atomically clear isRunning; only a call that observed true proceeds.
	if !peer.isRunning.Swap(false) {
		return
	}

	peer.device.log.Verbosef("%v - Stopping", peer)

	peer.timersStop()
	// Signal that RoutineSequentialSender and RoutineSequentialReceiver should exit.
	peer.queue.inbound.c <- nil
	peer.queue.outbound.c <- nil
	// Wait for the two routines registered by Start's stopping.Add(2).
	peer.stopping.Wait()
	peer.device.queue.encryption.wg.Done() // no more writes to encryption queue from us

	peer.ZeroAndFlushAll()
}
   278  
   279  func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) {
   280  	peer.endpoint.Lock()
   281  	defer peer.endpoint.Unlock()
   282  	if peer.endpoint.disableRoaming {
   283  		return
   284  	}
   285  	peer.endpoint.clearSrcOnTx = false
   286  	peer.endpoint.val = endpoint
   287  }
   288  
   289  func (peer *Peer) markEndpointSrcForClearing() {
   290  	peer.endpoint.Lock()
   291  	defer peer.endpoint.Unlock()
   292  	if peer.endpoint.val == nil {
   293  		return
   294  	}
   295  	peer.endpoint.clearSrcOnTx = true
   296  }