github.com/noisysockets/noisysockets@v0.21.2-0.20240515114641-7f467e651c90/internal/transport/peer.go (about)

     1  // SPDX-License-Identifier: MPL-2.0
     2  /*
     3   * Copyright (C) 2024 The Noisy Sockets Authors.
     4   *
     5   * This Source Code Form is subject to the terms of the Mozilla Public
     6   * License, v. 2.0. If a copy of the MPL was not distributed with this
     7   * file, You can obtain one at http://mozilla.org/MPL/2.0/.
     8   *
     9   * Portions of this file are based on code originally from wireguard-go,
    10   *
    11   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
    12   *
    13   * Permission is hereby granted, free of charge, to any person obtaining a copy of
    14   * this software and associated documentation files (the "Software"), to deal in
    15   * the Software without restriction, including without limitation the rights to
    16   * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
    17   * of the Software, and to permit persons to whom the Software is furnished to do
    18   * so, subject to the following conditions:
    19   *
    20   * The above copyright notice and this permission notice shall be included in all
    21   * copies or substantial portions of the Software.
    22   *
    23   * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    24   * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    25   * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    26   * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    27   * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    28   * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    29   * SOFTWARE.
    30   */
    31  
    32  package transport
    33  
    34  import (
    35  	"errors"
    36  	"fmt"
    37  	"log/slog"
    38  	"sync"
    39  	"sync/atomic"
    40  	"time"
    41  
    42  	"github.com/noisysockets/noisysockets/internal/conn"
    43  	"github.com/noisysockets/noisysockets/types"
    44  )
    45  
// Peer holds the per-peer state kept by a Transport: the peer's public key,
// its keypairs and handshake state, its endpoint, timers and packet queues.
type Peer struct {
	pk                types.NoisePublicKey // peer's public key; also its key in transport.peers.keyMap
	isRunning         atomic.Bool          // set by Start, cleared by Stop
	keypairs          Keypairs             // previous, current and next keypairs
	handshake         Handshake            // guarded by handshake.mutex
	transport         *Transport           // transport that owns this peer
	stopping          sync.WaitGroup       // routines pending stop
	txBytes           atomic.Uint64        // bytes sent to peer (endpoint)
	rxBytes           atomic.Uint64        // bytes received from peer
	lastHandshakeNano atomic.Int64         // nanoseconds since epoch

	// endpoint is the remote address SendBuffers transmits to; it is nil
	// until SetEndpoint assigns one.
	endpoint struct {
		sync.Mutex
		val conn.Endpoint
	}

	// timers and their bookkeeping; set up by timersInit and controlled by
	// timersStart/timersStop (defined elsewhere in the package).
	timers struct {
		retransmitHandshake     *Timer
		sendKeepalive           *Timer
		newHandshake            *Timer
		zeroKeyMaterial         *Timer
		keepAlive               *Timer
		handshakeAttempts       atomic.Uint32
		needAnotherKeepalive    atomic.Bool
		sentLastMinuteHandshake atomic.Bool
	}

	state struct {
		sync.Mutex // protects against concurrent Start/Stop
	}

	queue struct {
		staged   chan *QueueOutboundElementsContainer // staged packets before a handshake is available
		outbound *autodrainingOutboundQueue           // sequential ordering of udp transmission
		inbound  *autodrainingInboundQueue            // sequential ordering of sink writing
	}

	cookieGenerator   CookieGenerator
	keepAliveInterval atomic.Uint32 // keepalive interval in whole seconds (see SetKeepAliveInterval)
}
    86  
    87  func (transport *Transport) NewPeer(pk types.NoisePublicKey) (*Peer, error) {
    88  	if transport.isClosed() {
    89  		return nil, errors.New("transport closed")
    90  	}
    91  
    92  	// lock resources
    93  	transport.staticIdentity.RLock()
    94  	defer transport.staticIdentity.RUnlock()
    95  
    96  	transport.peers.Lock()
    97  	defer transport.peers.Unlock()
    98  
    99  	// check if over limit
   100  	if len(transport.peers.keyMap) >= MaxPeers {
   101  		return nil, errors.New("too many peers")
   102  	}
   103  
   104  	// create peer
   105  	peer := new(Peer)
   106  
   107  	peer.pk = pk
   108  	peer.cookieGenerator.Init(pk)
   109  	peer.transport = transport
   110  	peer.queue.outbound = newAutodrainingOutboundQueue(transport)
   111  	peer.queue.inbound = newAutodrainingInboundQueue(transport)
   112  	peer.queue.staged = make(chan *QueueOutboundElementsContainer, QueueStagedSize)
   113  
   114  	// map public key
   115  	_, ok := transport.peers.keyMap[pk]
   116  	if ok {
   117  		return nil, errors.New("adding existing peer")
   118  	}
   119  
   120  	// pre-compute DH
   121  	handshake := &peer.handshake
   122  	handshake.mutex.Lock()
   123  	handshake.precomputedStaticStatic, _ = sharedSecret(transport.staticIdentity.privateKey, pk)
   124  	handshake.remoteStatic = pk
   125  	handshake.mutex.Unlock()
   126  
   127  	// reset endpoint
   128  	peer.endpoint.Lock()
   129  	peer.endpoint.val = nil
   130  	peer.endpoint.Unlock()
   131  
   132  	// init timers
   133  	peer.timersInit()
   134  
   135  	// add
   136  	transport.peers.keyMap[pk] = peer
   137  
   138  	return peer, nil
   139  }
   140  
   141  func (peer *Peer) SendBuffers(buffers [][]byte) error {
   142  	peer.transport.net.RLock()
   143  	defer peer.transport.net.RUnlock()
   144  
   145  	if peer.transport.isClosed() {
   146  		return nil
   147  	}
   148  
   149  	peer.endpoint.Lock()
   150  	endpoint := peer.endpoint.val
   151  	if endpoint == nil {
   152  		peer.endpoint.Unlock()
   153  		return errors.New("no known endpoint for peer")
   154  	}
   155  	peer.endpoint.Unlock()
   156  
   157  	err := peer.transport.net.bind.Send(buffers, endpoint)
   158  	if err == nil {
   159  		var totalLen uint64
   160  		for _, b := range buffers {
   161  			totalLen += uint64(len(b))
   162  		}
   163  		peer.txBytes.Add(totalLen)
   164  	}
   165  	return err
   166  }
   167  
   168  func (peer *Peer) String() string {
   169  	return fmt.Sprintf("peer(%s)", peer.handshake.remoteStatic.DisplayString())
   170  }
   171  
// Start starts the peer's timers and launches its RoutineSequentialSender and
// RoutineSequentialReceiver goroutines. It is a no-op if the transport is
// closed or the peer is already running. Start and Stop are serialized by
// peer.state.
func (peer *Peer) Start() {
	// should never start a peer on a closed transport.
	if peer.transport.isClosed() {
		return
	}

	// prevent simultaneous start/stop operations
	peer.state.Lock()
	defer peer.state.Unlock()

	if peer.isRunning.Load() {
		return
	}

	transport := peer.transport
	transport.logger.Debug("Starting", slog.String("peer", peer.String()))

	// reset routine state: wait for routines from any previous run to exit,
	// then account for the two routines launched below (Stop waits on this).
	peer.stopping.Wait()
	peer.stopping.Add(2)

	// Backdate lastSentHandshake by more than RekeyTimeout; presumably so a
	// handshake may be initiated immediately after starting — confirm against
	// the code that reads lastSentHandshake.
	peer.handshake.mutex.Lock()
	peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
	peer.handshake.mutex.Unlock()

	peer.transport.queue.encryption.wg.Add(1) // keep encryption queue open for our writes

	peer.timersStart()

	// Flush anything left in the per-peer queues before the new routines
	// start consuming them.
	transport.flushInboundQueue(peer.queue.inbound)
	transport.flushOutboundQueue(peer.queue.outbound)

	// Use the transport batch size, not the bind batch size, as the transport size is
	// the size of the batch pools.
	batchSize := peer.transport.BatchSize()
	go peer.RoutineSequentialSender(batchSize)
	go peer.RoutineSequentialReceiver(batchSize)

	// Mark the peer running only after its routines have been launched.
	peer.isRunning.Store(true)
}
   212  
   213  func (peer *Peer) ZeroAndFlushAll() {
   214  	transport := peer.transport
   215  
   216  	// clear key pairs
   217  
   218  	keypairs := &peer.keypairs
   219  	keypairs.Lock()
   220  	transport.DeleteKeypair(keypairs.previous)
   221  	transport.DeleteKeypair(keypairs.current)
   222  	transport.DeleteKeypair(keypairs.next.Load())
   223  	keypairs.previous = nil
   224  	keypairs.current = nil
   225  	keypairs.next.Store(nil)
   226  	keypairs.Unlock()
   227  
   228  	// clear handshake state
   229  
   230  	handshake := &peer.handshake
   231  	handshake.mutex.Lock()
   232  	transport.indexTable.Delete(handshake.localIndex)
   233  	handshake.Clear()
   234  	handshake.mutex.Unlock()
   235  
   236  	peer.FlushStagedPackets()
   237  }
   238  
   239  func (peer *Peer) ExpireCurrentKeypairs() {
   240  	handshake := &peer.handshake
   241  	handshake.mutex.Lock()
   242  	peer.transport.indexTable.Delete(handshake.localIndex)
   243  	handshake.Clear()
   244  	peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
   245  	handshake.mutex.Unlock()
   246  
   247  	keypairs := &peer.keypairs
   248  	keypairs.Lock()
   249  	if keypairs.current != nil {
   250  		keypairs.current.sendNonce.Store(RejectAfterMessages)
   251  	}
   252  	if next := keypairs.next.Load(); next != nil {
   253  		next.sendNonce.Store(RejectAfterMessages)
   254  	}
   255  	keypairs.Unlock()
   256  }
   257  
// Stop shuts down a running peer. It stops the peer's timers and signals
// RoutineSequentialSender and RoutineSequentialReceiver to exit by sending nil
// on their queues. It then waits for both to finish and releases the
// encryption-queue reference taken by Start. Finally, it zeroes keypairs and
// handshake state via ZeroAndFlushAll. Stopping a peer that is not running
// is a no-op.
func (peer *Peer) Stop() {
	peer.state.Lock()
	defer peer.state.Unlock()

	// Clear the running flag; shut down only if it was previously set.
	if !peer.isRunning.Swap(false) {
		return
	}

	peer.transport.logger.Debug("Stopping", slog.String("peer", peer.String()))

	peer.timersStop()
	// Signal that RoutineSequentialSender and RoutineSequentialReceiver should exit.
	peer.queue.inbound.c <- nil
	peer.queue.outbound.c <- nil
	// Wait for the two routines counted by Start's stopping.Add(2).
	peer.stopping.Wait()
	peer.transport.queue.encryption.wg.Done() // no more writes to encryption queue from us

	peer.ZeroAndFlushAll()
}
   277  
   278  func (peer *Peer) GetEndpoint() conn.Endpoint {
   279  	peer.endpoint.Lock()
   280  	defer peer.endpoint.Unlock()
   281  	return peer.endpoint.val
   282  }
   283  
   284  func (peer *Peer) SetEndpoint(endpoint conn.Endpoint) {
   285  	peer.endpoint.Lock()
   286  	defer peer.endpoint.Unlock()
   287  	peer.endpoint.val = endpoint
   288  }
   289  
   290  func (peer *Peer) SetKeepAliveInterval(interval time.Duration) {
   291  	peer.keepAliveInterval.Store(uint32(interval.Seconds()))
   292  }