github.com/bepass-org/wireguard-go@v1.0.4-rc2.0.20240304192354-ebce6572bc24/device/peer.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"container/list"
    10  	"errors"
    11  	"sync"
    12  	"sync/atomic"
    13  	"time"
    14  
    15  	"github.com/bepass-org/wireguard-go/conn"
    16  )
    17  
    18  type Peer struct {
    19  	isRunning         atomic.Bool
    20  	keypairs          Keypairs
    21  	handshake         Handshake
    22  	device            *Device
    23  	stopping          sync.WaitGroup // routines pending stop
    24  	txBytes           atomic.Uint64  // bytes send to peer (endpoint)
    25  	rxBytes           atomic.Uint64  // bytes received from peer
    26  	lastHandshakeNano atomic.Int64   // nano seconds since epoch
    27  
    28  	endpoint struct {
    29  		sync.Mutex
    30  		val            conn.Endpoint
    31  		clearSrcOnTx   bool // signal to val.ClearSrc() prior to next packet transmission
    32  		disableRoaming bool
    33  	}
    34  
    35  	timers struct {
    36  		retransmitHandshake     *Timer
    37  		sendKeepalive           *Timer
    38  		newHandshake            *Timer
    39  		zeroKeyMaterial         *Timer
    40  		persistentKeepalive     *Timer
    41  		handshakeAttempts       atomic.Uint32
    42  		needAnotherKeepalive    atomic.Bool
    43  		sentLastMinuteHandshake atomic.Bool
    44  	}
    45  
    46  	state struct {
    47  		sync.Mutex // protects against concurrent Start/Stop
    48  	}
    49  
    50  	queue struct {
    51  		staged   chan *QueueOutboundElementsContainer // staged packets before a handshake is available
    52  		outbound *autodrainingOutboundQueue           // sequential ordering of udp transmission
    53  		inbound  *autodrainingInboundQueue            // sequential ordering of tun writing
    54  	}
    55  
    56  	trick  bool
    57  	stopCh chan int
    58  
    59  	cookieGenerator             CookieGenerator
    60  	trieEntries                 list.List
    61  	persistentKeepaliveInterval atomic.Uint32
    62  }
    63  
    64  func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
    65  	if device.isClosed() {
    66  		return nil, errors.New("device closed")
    67  	}
    68  
    69  	// lock resources
    70  	device.staticIdentity.RLock()
    71  	defer device.staticIdentity.RUnlock()
    72  
    73  	device.peers.Lock()
    74  	defer device.peers.Unlock()
    75  
    76  	// check if over limit
    77  	if len(device.peers.keyMap) >= MaxPeers {
    78  		return nil, errors.New("too many peers")
    79  	}
    80  
    81  	// create peer
    82  	peer := new(Peer)
    83  	peer.stopCh = make(chan int, 1)
    84  	peer.trick = true
    85  	peer.cookieGenerator.Init(pk)
    86  	peer.device = device
    87  	peer.queue.outbound = newAutodrainingOutboundQueue(device)
    88  	peer.queue.inbound = newAutodrainingInboundQueue(device)
    89  	peer.queue.staged = make(chan *QueueOutboundElementsContainer, QueueStagedSize)
    90  
    91  	// map public key
    92  	_, ok := device.peers.keyMap[pk]
    93  	if ok {
    94  		return nil, errors.New("adding existing peer")
    95  	}
    96  
    97  	// pre-compute DH
    98  	handshake := &peer.handshake
    99  	handshake.mutex.Lock()
   100  	handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(pk)
   101  	handshake.remoteStatic = pk
   102  	handshake.mutex.Unlock()
   103  
   104  	// reset endpoint
   105  	peer.endpoint.Lock()
   106  	peer.endpoint.val = nil
   107  	peer.endpoint.disableRoaming = false
   108  	peer.endpoint.clearSrcOnTx = false
   109  	peer.endpoint.Unlock()
   110  
   111  	// init timers
   112  	peer.timersInit()
   113  
   114  	// add
   115  	device.peers.keyMap[pk] = peer
   116  
   117  	return peer, nil
   118  }
   119  
   120  func (peer *Peer) SendBuffers(buffers [][]byte) error {
   121  	peer.device.net.RLock()
   122  	defer peer.device.net.RUnlock()
   123  
   124  	if peer.device.isClosed() {
   125  		return nil
   126  	}
   127  
   128  	peer.endpoint.Lock()
   129  	endpoint := peer.endpoint.val
   130  	if endpoint == nil {
   131  		peer.endpoint.Unlock()
   132  		return errors.New("no known endpoint for peer")
   133  	}
   134  	if peer.endpoint.clearSrcOnTx {
   135  		endpoint.ClearSrc()
   136  		peer.endpoint.clearSrcOnTx = false
   137  	}
   138  	peer.endpoint.Unlock()
   139  
   140  	err := peer.device.net.bind.Send(buffers, endpoint)
   141  	if err == nil {
   142  		var totalLen uint64
   143  		for _, b := range buffers {
   144  			totalLen += uint64(len(b))
   145  		}
   146  		peer.txBytes.Add(totalLen)
   147  	}
   148  	return err
   149  }
   150  
   151  func (peer *Peer) String() string {
   152  	// The awful goo that follows is identical to:
   153  	//
   154  	//   base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:])
   155  	//   abbreviatedKey := base64Key[0:4] + "…" + base64Key[39:43]
   156  	//   return fmt.Sprintf("peer(%s)", abbreviatedKey)
   157  	//
   158  	// except that it is considerably more efficient.
   159  	src := peer.handshake.remoteStatic
   160  	b64 := func(input byte) byte {
   161  		return input + 'A' + byte(((25-int(input))>>8)&6) - byte(((51-int(input))>>8)&75) - byte(((61-int(input))>>8)&15) + byte(((62-int(input))>>8)&3)
   162  	}
   163  	b := []byte("peer(____…____)")
   164  	const first = len("peer(")
   165  	const second = len("peer(____…")
   166  	b[first+0] = b64((src[0] >> 2) & 63)
   167  	b[first+1] = b64(((src[0] << 4) | (src[1] >> 4)) & 63)
   168  	b[first+2] = b64(((src[1] << 2) | (src[2] >> 6)) & 63)
   169  	b[first+3] = b64(src[2] & 63)
   170  	b[second+0] = b64(src[29] & 63)
   171  	b[second+1] = b64((src[30] >> 2) & 63)
   172  	b[second+2] = b64(((src[30] << 4) | (src[31] >> 4)) & 63)
   173  	b[second+3] = b64((src[31] << 2) & 63)
   174  	return string(b)
   175  }
   176  
   177  func (peer *Peer) Start() {
   178  	// should never start a peer on a closed device
   179  	if peer.device.isClosed() {
   180  		return
   181  	}
   182  
   183  	// prevent simultaneous start/stop operations
   184  	peer.state.Lock()
   185  	defer peer.state.Unlock()
   186  
   187  	if peer.isRunning.Load() {
   188  		return
   189  	}
   190  
   191  	device := peer.device
   192  	device.log.Verbosef("%v - Starting", peer)
   193  
   194  	// reset routine state
   195  	peer.stopping.Wait()
   196  	peer.stopping.Add(2)
   197  
   198  	peer.handshake.mutex.Lock()
   199  	peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
   200  	peer.handshake.mutex.Unlock()
   201  
   202  	peer.device.queue.encryption.wg.Add(1) // keep encryption queue open for our writes
   203  
   204  	peer.timersStart()
   205  
   206  	device.flushInboundQueue(peer.queue.inbound)
   207  	device.flushOutboundQueue(peer.queue.outbound)
   208  
   209  	// Use the device batch size, not the bind batch size, as the device size is
   210  	// the size of the batch pools.
   211  	batchSize := peer.device.BatchSize()
   212  	go peer.RoutineSequentialSender(batchSize)
   213  	go peer.RoutineSequentialReceiver(batchSize)
   214  
   215  	peer.isRunning.Store(true)
   216  }
   217  
   218  func (peer *Peer) ZeroAndFlushAll() {
   219  	device := peer.device
   220  
   221  	// clear key pairs
   222  
   223  	keypairs := &peer.keypairs
   224  	keypairs.Lock()
   225  	device.DeleteKeypair(keypairs.previous)
   226  	device.DeleteKeypair(keypairs.current)
   227  	device.DeleteKeypair(keypairs.next.Load())
   228  	keypairs.previous = nil
   229  	keypairs.current = nil
   230  	keypairs.next.Store(nil)
   231  	keypairs.Unlock()
   232  
   233  	// clear handshake state
   234  
   235  	handshake := &peer.handshake
   236  	handshake.mutex.Lock()
   237  	device.indexTable.Delete(handshake.localIndex)
   238  	handshake.Clear()
   239  	handshake.mutex.Unlock()
   240  
   241  	peer.FlushStagedPackets()
   242  }
   243  
   244  func (peer *Peer) ExpireCurrentKeypairs() {
   245  	handshake := &peer.handshake
   246  	handshake.mutex.Lock()
   247  	peer.device.indexTable.Delete(handshake.localIndex)
   248  	handshake.Clear()
   249  	peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
   250  	handshake.mutex.Unlock()
   251  
   252  	keypairs := &peer.keypairs
   253  	keypairs.Lock()
   254  	if keypairs.current != nil {
   255  		keypairs.current.sendNonce.Store(RejectAfterMessages)
   256  	}
   257  	if next := keypairs.next.Load(); next != nil {
   258  		next.sendNonce.Store(RejectAfterMessages)
   259  	}
   260  	keypairs.Unlock()
   261  }
   262  
   263  func (peer *Peer) Stop() {
   264  	peer.state.Lock()
   265  	defer peer.state.Unlock()
   266  
   267  	if !peer.isRunning.Swap(false) {
   268  		return
   269  	}
   270  
   271  	select {
   272  	case peer.stopCh <- 1:
   273  	default:
   274  	}
   275  	peer.device.log.Verbosef("%v - Stopping", peer)
   276  
   277  	peer.timersStop()
   278  	// Signal that RoutineSequentialSender and RoutineSequentialReceiver should exit.
   279  	peer.queue.inbound.c <- nil
   280  	peer.queue.outbound.c <- nil
   281  	peer.stopping.Wait()
   282  	peer.device.queue.encryption.wg.Done() // no more writes to encryption queue from us
   283  
   284  	peer.ZeroAndFlushAll()
   285  }
   286  
   287  func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) {
   288  	peer.endpoint.Lock()
   289  	defer peer.endpoint.Unlock()
   290  	if peer.endpoint.disableRoaming {
   291  		return
   292  	}
   293  	peer.endpoint.clearSrcOnTx = false
   294  	peer.endpoint.val = endpoint
   295  }
   296  
   297  func (peer *Peer) markEndpointSrcForClearing() {
   298  	peer.endpoint.Lock()
   299  	defer peer.endpoint.Unlock()
   300  	if peer.endpoint.val == nil {
   301  		return
   302  	}
   303  	peer.endpoint.clearSrcOnTx = true
   304  }