github.com/sagernet/wireguard-go@v0.0.0-20231215174105-89dec3b2f3e8/device/peer.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"container/list"
    10  	"errors"
    11  	"sync"
    12  	"time"
    13  
    14  	"github.com/sagernet/sing/common/atomic"
    15  	"github.com/sagernet/wireguard-go/conn"
    16  )
    17  
// Peer is the per-peer state kept by a Device: the peer's handshake and
// keypairs, its current endpoint, traffic counters, timers, and the queues
// that order its inbound and outbound packets.
//
// A Peer is created with Device.NewPeer and is driven by Start and Stop.
type Peer struct {
	// isRunning is set at the end of Start and cleared at the beginning of Stop.
	isRunning         atomic.Bool
	sync.RWMutex      // Mostly protects endpoint, but is generally taken whenever we modify peer
	// keypairs holds the previous, current and next keypairs; all three are
	// deleted and cleared by ZeroAndFlushAll.
	keypairs          Keypairs
	// handshake holds the handshake state, including the peer's static public
	// key (remoteStatic); it is guarded by its own mutex.
	handshake         Handshake
	// device is the Device this peer belongs to; set once in NewPeer.
	device            *Device
	// endpoint is the destination used by SendBuffers; nil until one is known.
	endpoint          conn.Endpoint
	stopping          sync.WaitGroup // routines pending stop
	txBytes           atomic.Uint64  // bytes sent to peer (endpoint)
	rxBytes           atomic.Uint64  // bytes received from peer
	lastHandshakeNano atomic.Int64   // nano seconds since epoch

	// disableRoaming, when true, makes SetEndpointFromPacket a no-op so the
	// endpoint is never replaced by the source of a received packet.
	disableRoaming bool

	// timers groups the peer's timers and their bookkeeping flags.
	// NOTE(review): presumably created by timersInit and driven by
	// timersStart/timersStop, which are defined outside this file — confirm.
	timers struct {
		retransmitHandshake     *Timer
		sendKeepalive           *Timer
		newHandshake            *Timer
		zeroKeyMaterial         *Timer
		persistentKeepalive     *Timer
		handshakeAttempts       atomic.Uint32
		needAnotherKeepalive    atomic.Bool
		sentLastMinuteHandshake atomic.Bool
	}

	state struct {
		sync.Mutex // protects against concurrent Start/Stop
	}

	queue struct {
		staged   chan *QueueOutboundElementsContainer // staged packets before a handshake is available
		outbound *autodrainingOutboundQueue           // sequential ordering of udp transmission
		inbound  *autodrainingInboundQueue            // sequential ordering of tun writing
	}

	// cookieGenerator is initialised with the peer's public key in NewPeer.
	cookieGenerator             CookieGenerator
	// NOTE(review): trieEntries is not used in this file; presumably it links
	// the peer to its allowed-IPs trie entries — confirm.
	trieEntries                 list.List
	// NOTE(review): the unit of persistentKeepaliveInterval is not visible
	// here (presumably seconds) — confirm.
	persistentKeepaliveInterval atomic.Uint32
}
    57  
    58  func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
    59  	if device.isClosed() {
    60  		return nil, errors.New("device closed")
    61  	}
    62  
    63  	// lock resources
    64  	device.staticIdentity.RLock()
    65  	defer device.staticIdentity.RUnlock()
    66  
    67  	device.peers.Lock()
    68  	defer device.peers.Unlock()
    69  
    70  	// check if over limit
    71  	if len(device.peers.keyMap) >= MaxPeers {
    72  		return nil, errors.New("too many peers")
    73  	}
    74  
    75  	// create peer
    76  	peer := new(Peer)
    77  	peer.Lock()
    78  	defer peer.Unlock()
    79  
    80  	peer.cookieGenerator.Init(pk)
    81  	peer.device = device
    82  	peer.queue.outbound = newAutodrainingOutboundQueue(device)
    83  	peer.queue.inbound = newAutodrainingInboundQueue(device)
    84  	peer.queue.staged = make(chan *QueueOutboundElementsContainer, QueueStagedSize)
    85  
    86  	// map public key
    87  	_, ok := device.peers.keyMap[pk]
    88  	if ok {
    89  		return nil, errors.New("adding existing peer")
    90  	}
    91  
    92  	// pre-compute DH
    93  	handshake := &peer.handshake
    94  	handshake.mutex.Lock()
    95  	handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(pk)
    96  	handshake.remoteStatic = pk
    97  	handshake.mutex.Unlock()
    98  
    99  	// reset endpoint
   100  	peer.endpoint = nil
   101  
   102  	// init timers
   103  	peer.timersInit()
   104  
   105  	// add
   106  	device.peers.keyMap[pk] = peer
   107  
   108  	return peer, nil
   109  }
   110  
   111  func (peer *Peer) SendBuffers(buffers [][]byte) error {
   112  	peer.device.net.RLock()
   113  	defer peer.device.net.RUnlock()
   114  
   115  	if peer.device.isClosed() {
   116  		return nil
   117  	}
   118  
   119  	peer.RLock()
   120  	defer peer.RUnlock()
   121  
   122  	if peer.endpoint == nil {
   123  		return errors.New("no known endpoint for peer")
   124  	}
   125  
   126  	err := peer.device.net.bind.Send(buffers, peer.endpoint)
   127  	if err == nil {
   128  		var totalLen uint64
   129  		for _, b := range buffers {
   130  			totalLen += uint64(len(b))
   131  		}
   132  		peer.txBytes.Add(totalLen)
   133  	}
   134  	return err
   135  }
   136  
// String returns a short identifier for the peer of the form
// "peer(XXXX…YYYY)", where XXXX and YYYY are the first four characters and
// the four characters at offsets 39-42 of the standard base64 encoding of the
// peer's static public key (handshake.remoteStatic).
func (peer *Peer) String() string {
	// The awful goo that follows is identical to:
	//
	//   base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:])
	//   abbreviatedKey := base64Key[0:4] + "…" + base64Key[39:43]
	//   return fmt.Sprintf("peer(%s)", abbreviatedKey)
	//
	// except that it is considerably more efficient.
	src := peer.handshake.remoteStatic
	// b64 maps a 6-bit value (0-63) to its character in the standard base64
	// alphabet without branching. Each (k-int(input))>>8 term is all ones
	// when input > k and zero otherwise, so the masks apply the offset
	// corrections needed when crossing from 'A'-'Z' to 'a'-'z' (input > 25),
	// to '0'-'9' (input > 51), to '+' (input > 61) and to '/' (input > 62).
	b64 := func(input byte) byte {
		return input + 'A' + byte(((25-int(input))>>8)&6) - byte(((51-int(input))>>8)&75) - byte(((61-int(input))>>8)&15) + byte(((62-int(input))>>8)&3)
	}
	// Template whose underscores are overwritten below; "…" occupies several
	// bytes in UTF-8, which is why the offsets are computed with len.
	b := []byte("peer(____…____)")
	const first = len("peer(")
	const second = len("peer(____…")
	// First four base64 characters: the 24 bits of src[0..2].
	b[first+0] = b64((src[0] >> 2) & 63)
	b[first+1] = b64(((src[0] << 4) | (src[1] >> 4)) & 63)
	b[first+2] = b64(((src[1] << 2) | (src[2] >> 6)) & 63)
	b[first+3] = b64(src[2] & 63)
	// Base64 characters 39-42: the low six bits of src[29], then the 16 bits
	// of src[30..31] padded with two zero bits.
	b[second+0] = b64(src[29] & 63)
	b[second+1] = b64((src[30] >> 2) & 63)
	b[second+2] = b64(((src[30] << 4) | (src[31] >> 4)) & 63)
	b[second+3] = b64((src[31] << 2) & 63)
	return string(b)
}
   162  
   163  func (peer *Peer) Start() {
   164  	// should never start a peer on a closed device
   165  	if peer.device.isClosed() {
   166  		return
   167  	}
   168  
   169  	// prevent simultaneous start/stop operations
   170  	peer.state.Lock()
   171  	defer peer.state.Unlock()
   172  
   173  	if peer.isRunning.Load() {
   174  		return
   175  	}
   176  
   177  	device := peer.device
   178  	device.log.Verbosef("%v - Starting", peer)
   179  
   180  	// reset routine state
   181  	peer.stopping.Wait()
   182  	peer.stopping.Add(2)
   183  
   184  	peer.handshake.mutex.Lock()
   185  	peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
   186  	peer.handshake.mutex.Unlock()
   187  
   188  	peer.device.queue.encryption.wg.Add(1) // keep encryption queue open for our writes
   189  
   190  	peer.timersStart()
   191  
   192  	device.flushInboundQueue(peer.queue.inbound)
   193  	device.flushOutboundQueue(peer.queue.outbound)
   194  
   195  	// Use the device batch size, not the bind batch size, as the device size is
   196  	// the size of the batch pools.
   197  	batchSize := peer.device.BatchSize()
   198  	go peer.RoutineSequentialSender(batchSize)
   199  	go peer.RoutineSequentialReceiver(batchSize)
   200  
   201  	peer.isRunning.Store(true)
   202  }
   203  
   204  func (peer *Peer) ZeroAndFlushAll() {
   205  	device := peer.device
   206  
   207  	// clear key pairs
   208  
   209  	keypairs := &peer.keypairs
   210  	keypairs.Lock()
   211  	device.DeleteKeypair(keypairs.previous)
   212  	device.DeleteKeypair(keypairs.current)
   213  	device.DeleteKeypair(keypairs.next.Load())
   214  	keypairs.previous = nil
   215  	keypairs.current = nil
   216  	keypairs.next.Store(nil)
   217  	keypairs.Unlock()
   218  
   219  	// clear handshake state
   220  
   221  	handshake := &peer.handshake
   222  	handshake.mutex.Lock()
   223  	device.indexTable.Delete(handshake.localIndex)
   224  	handshake.Clear()
   225  	handshake.mutex.Unlock()
   226  
   227  	peer.FlushStagedPackets()
   228  }
   229  
   230  func (peer *Peer) ExpireCurrentKeypairs() {
   231  	handshake := &peer.handshake
   232  	handshake.mutex.Lock()
   233  	peer.device.indexTable.Delete(handshake.localIndex)
   234  	handshake.Clear()
   235  	peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
   236  	handshake.mutex.Unlock()
   237  
   238  	keypairs := &peer.keypairs
   239  	keypairs.Lock()
   240  	if keypairs.current != nil {
   241  		keypairs.current.sendNonce.Store(RejectAfterMessages)
   242  	}
   243  	if next := keypairs.next.Load(); next != nil {
   244  		next.sendNonce.Store(RejectAfterMessages)
   245  	}
   246  	keypairs.Unlock()
   247  }
   248  
   249  func (peer *Peer) Stop() {
   250  	peer.state.Lock()
   251  	defer peer.state.Unlock()
   252  
   253  	if !peer.isRunning.Swap(false) {
   254  		return
   255  	}
   256  
   257  	peer.device.log.Verbosef("%v - Stopping", peer)
   258  
   259  	peer.timersStop()
   260  	// Signal that RoutineSequentialSender and RoutineSequentialReceiver should exit.
   261  	peer.queue.inbound.c <- nil
   262  	peer.queue.outbound.c <- nil
   263  	peer.stopping.Wait()
   264  	peer.device.queue.encryption.wg.Done() // no more writes to encryption queue from us
   265  
   266  	peer.ZeroAndFlushAll()
   267  }
   268  
   269  func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) {
   270  	if peer.disableRoaming {
   271  		return
   272  	}
   273  	peer.Lock()
   274  	peer.endpoint = endpoint
   275  	peer.Unlock()
   276  }