github.com/noisysockets/noisysockets@v0.21.2-0.20240515114641-7f467e651c90/internal/transport/transport.go (about)

     1  // SPDX-License-Identifier: MPL-2.0
     2  /*
     3   * Copyright (C) 2024 The Noisy Sockets Authors.
     4   *
     5   * This Source Code Form is subject to the terms of the Mozilla Public
     6   * License, v. 2.0. If a copy of the MPL was not distributed with this
     7   * file, You can obtain one at http://mozilla.org/MPL/2.0/.
     8   *
     9   * Portions of this file are based on code originally from wireguard-go,
    10   *
    11   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
    12   *
    13   * Permission is hereby granted, free of charge, to any person obtaining a copy of
    14   * this software and associated documentation files (the "Software"), to deal in
    15   * the Software without restriction, including without limitation the rights to
    16   * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
    17   * of the Software, and to permit persons to whom the Software is furnished to do
    18   * so, subject to the following conditions:
    19   *
    20   * The above copyright notice and this permission notice shall be included in all
    21   * copies or substantial portions of the Software.
    22   *
    23   * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    24   * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    25   * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    26   * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    27   * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    28   * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    29   * SOFTWARE.
    30   */
    31  
    32  package transport
    33  
    34  import (
    35  	"fmt"
    36  	"log/slog"
    37  	"runtime"
    38  	"sync"
    39  	"sync/atomic"
    40  	"time"
    41  
    42  	"github.com/noisysockets/noisysockets/internal/conn"
    43  	"github.com/noisysockets/noisysockets/internal/ratelimiter"
    44  	"github.com/noisysockets/noisysockets/types"
    45  )
    46  
// Transport bundles everything a single transport instance needs: its
// lifecycle state, the network bind, the local static identity, the set of
// known peers, handshake rate limiting, memory pools and the work queues.
type Transport struct {
	// logger receives the transport's debug and error messages.
	logger *slog.Logger

	state struct {
		// state holds the transport's state. It is accessed atomically.
		// Use the transport.transportState method to read it.
		// transport.transportState does not acquire the mutex, so it captures only a snapshot.
		// During state transitions, the state variable is updated before the transport itself.
		// The state is thus either the current state of the transport or
		// the intended future state of the transport.
		// For example, while executing a call to Up, state will be transportStateUp.
		// There is no guarantee that the intended future state of the transport
		// will become the actual state; Up can fail.
		// The transport can also change state multiple times between time of check and time of use.
		// Unsynchronized uses of state must therefore be advisory/best-effort only.
		state atomic.Uint32 // actually a transportState, but typed uint32 for convenience
		// stopping blocks until all inputs to Transport have been closed.
		stopping sync.WaitGroup
		// The embedded mutex protects state changes.
		sync.Mutex
	}

	net struct {
		// stopping is incremented once per receive function started by
		// BindUpdate and waited on when the bind is closed.
		stopping sync.WaitGroup
		sync.RWMutex
		bind conn.Bind // bind interface
		port uint16    // listening port
	}

	// staticIdentity is the local key pair; it is replaced by SetPrivateKey.
	staticIdentity struct {
		sync.RWMutex
		privateKey types.NoisePrivateKey
		publicKey  types.NoisePublicKey
	}

	peers struct {
		sync.RWMutex // protects keyMap
		keyMap       map[types.NoisePublicKey]*Peer
	}

	rate struct {
		// underLoadUntil is a UnixNano timestamp; IsUnderLoad keeps reporting
		// true until this instant has passed.
		underLoadUntil atomic.Int64
		limiter        ratelimiter.Ratelimiter
	}

	indexTable    IndexTable
	cookieChecker CookieChecker

	// pool groups the reusable-object pools.
	// NOTE(review): presumably filled in by PopulatePools — confirm.
	pool struct {
		inboundElementsContainer  *WaitPool
		outboundElementsContainer *WaitPool
		messageBuffers            *WaitPool
		inboundElements           *WaitPool
		outboundElements          *WaitPool
	}

	// queue groups the work queues created by NewTransport.
	queue struct {
		encryption *outboundQueue
		decryption *inboundQueue
		handshake  *handshakeQueue
	}

	sourceSink SourceSink

	// closed is the channel returned by Wait; Close closes it.
	closed chan struct{}
}
   113  
   114  // transportState represents the state of a Transport.
   115  // There are three states: down, up, closed.
   116  // Transitions:
   117  //
   118  //	down -----+
   119  //	  ↑↓      ↓
   120  //	  up -> closed
   121  type transportState uint32
   122  
   123  const (
   124  	transportStateDown transportState = iota
   125  	transportStateUp
   126  	transportStateClosed
   127  )
   128  
   129  func (state transportState) String() string {
   130  	switch state {
   131  	case transportStateDown:
   132  		return "down"
   133  	case transportStateUp:
   134  		return "up"
   135  	case transportStateClosed:
   136  		return "closed"
   137  	default:
   138  		return "unknown"
   139  	}
   140  }
   141  
   142  // transportState returns transport.state.state as a transportState
   143  // See those docs for how to interpret this value.
   144  func (transport *Transport) transportState() transportState {
   145  	return transportState(transport.state.state.Load())
   146  }
   147  
   148  // isClosed reports whether the transport is closed (or is closing).
   149  // See transport.state.state comments for how to interpret this value.
   150  func (transport *Transport) isClosed() bool {
   151  	return transport.transportState() == transportStateClosed
   152  }
   153  
   154  // isUp reports whether the transport is up (or is attempting to come up).
   155  // See transport.state.state comments for how to interpret this value.
   156  func (transport *Transport) isUp() bool {
   157  	return transport.transportState() == transportStateUp
   158  }
   159  
// removePeerLocked stops peer and deletes the entry stored under key from
// the transport's peer map.
//
// The caller must hold transport.peers.Lock(), since the map is mutated.
func removePeerLocked(transport *Transport, peer *Peer, key types.NoisePublicKey) {
	// stop routing and processing of packets
	peer.Stop()

	// remove from peer map
	delete(transport.peers.keyMap, key)
}
   168  
   169  // changeState attempts to change the transport state to match want.
   170  func (transport *Transport) changeState(want transportState) (err error) {
   171  	transport.state.Lock()
   172  	defer transport.state.Unlock()
   173  	old := transport.transportState()
   174  	if old == transportStateClosed {
   175  		// once closed, always closed
   176  		transport.logger.Debug("Interface closed, ignored requested state", slog.String("want", want.String()))
   177  		return nil
   178  	}
   179  	switch want {
   180  	case old:
   181  		return nil
   182  	case transportStateUp:
   183  		transport.state.state.Store(uint32(transportStateUp))
   184  		err = transport.upLocked()
   185  		if err == nil {
   186  			break
   187  		}
   188  		fallthrough // up failed; bring the transport all the way back down
   189  	case transportStateDown:
   190  		transport.state.state.Store(uint32(transportStateDown))
   191  		errDown := transport.downLocked()
   192  		if err == nil {
   193  			err = errDown
   194  		}
   195  	}
   196  	transport.logger.Debug("Interface state change requested",
   197  		slog.String("old", old.String()), slog.String("want", want.String()),
   198  		slog.String("now", transport.transportState().String()))
   199  	return
   200  }
   201  
   202  // upLocked attempts to bring the transport up and reports whether it succeeded.
   203  // The caller must hold transport.state.mu and is responsible for updating transport.state.state.
   204  func (transport *Transport) upLocked() error {
   205  	if err := transport.BindUpdate(); err != nil {
   206  		transport.logger.Error("Unable to update bind", slog.Any("error", err))
   207  		return err
   208  	}
   209  
   210  	transport.peers.RLock()
   211  	for _, peer := range transport.peers.keyMap {
   212  		peer.Start()
   213  		if peer.keepAliveInterval.Load() > 0 {
   214  			if err := peer.SendKeepalive(); err != nil {
   215  				transport.logger.Error("Failed to send keepalive",
   216  					slog.String("peer", peer.String()), slog.Any("error", err))
   217  			}
   218  		}
   219  	}
   220  	transport.peers.RUnlock()
   221  	return nil
   222  }
   223  
   224  // downLocked attempts to bring the transport down.
   225  // The caller must hold transport.state.mu and is responsible for updating transport.state.state.
   226  func (transport *Transport) downLocked() error {
   227  	err := transport.BindClose()
   228  	if err != nil {
   229  		transport.logger.Error("Bind close failed", slog.Any("error", err))
   230  	}
   231  
   232  	transport.peers.RLock()
   233  	for _, peer := range transport.peers.keyMap {
   234  		peer.Stop()
   235  	}
   236  	transport.peers.RUnlock()
   237  	return err
   238  }
   239  
// Up requests a transition to the up state and returns the error, if any,
// reported by changeState. It returns nil without doing anything when the
// transport is already up or has been closed.
func (transport *Transport) Up() error {
	return transport.changeState(transportStateUp)
}
   243  
// Down requests a transition to the down state and returns the error, if
// any, reported by changeState. It returns nil without doing anything when
// the transport is already down or has been closed.
func (transport *Transport) Down() error {
	return transport.changeState(transportStateDown)
}
   247  
   248  func (transport *Transport) IsUnderLoad() bool {
   249  	// check if currently under load
   250  	now := time.Now()
   251  	underLoad := len(transport.queue.handshake.c) >= QueueHandshakeSize/8
   252  	if underLoad {
   253  		transport.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime).UnixNano())
   254  		return true
   255  	}
   256  	// check if recently under load
   257  	return transport.rate.underLoadUntil.Load() > now.UnixNano()
   258  }
   259  
// SetPrivateKey installs sk as the transport's static private key and stores
// the public key derived from it.
//
// If sk equals the key already installed, nothing happens. Otherwise any
// peer whose remote static key equals the new public key is removed, the
// cookie checker is re-initialised with the new public key, the
// static-static shared secret is recomputed for every remaining peer, and
// the current keypairs of those peers are expired.
//
// The staticIdentity and peers locks are held for the whole operation.
func (transport *Transport) SetPrivateKey(sk types.NoisePrivateKey) {
	// lock required resources

	transport.staticIdentity.Lock()
	defer transport.staticIdentity.Unlock()

	if sk.Equals(transport.staticIdentity.privateKey) {
		return
	}

	transport.peers.Lock()
	defer transport.peers.Unlock()

	// Read-lock every peer's handshake state and remember which peers were
	// locked, so that exactly those are unlocked again further down (peers
	// removed from the map in between are still in this slice).
	lockedPeers := make([]*Peer, 0, len(transport.peers.keyMap))
	for _, peer := range transport.peers.keyMap {
		peer.handshake.mutex.RLock()
		lockedPeers = append(lockedPeers, peer)
	}

	// remove peers with matching public keys

	publicKey := sk.Public()
	for key, peer := range transport.peers.keyMap {
		if peer.handshake.remoteStatic.Equals(publicKey) {
			// The read lock is dropped around the removal and re-taken
			// afterwards so the unlock loop below stays balanced.
			// NOTE(review): presumably peer.Stop (called by removePeerLocked)
			// needs the handshake mutex — confirm.
			peer.handshake.mutex.RUnlock()
			removePeerLocked(transport, peer, key)
			peer.handshake.mutex.RLock()
		}
	}

	// update key material

	transport.staticIdentity.privateKey = sk
	transport.staticIdentity.publicKey = publicKey
	transport.cookieChecker.Init(publicKey)

	// do static-static DH pre-computations

	expiredPeers := make([]*Peer, 0, len(transport.peers.keyMap))
	for _, peer := range transport.peers.keyMap {
		handshake := &peer.handshake
		// The error from sharedSecret is deliberately discarded here.
		// NOTE(review): this field is written while only the handshake read
		// lock is held by this function — confirm that is intended.
		handshake.precomputedStaticStatic, _ = sharedSecret(transport.staticIdentity.privateKey, handshake.remoteStatic)
		expiredPeers = append(expiredPeers, peer)
	}

	// Release the handshake locks before expiring keypairs.
	for _, peer := range lockedPeers {
		peer.handshake.mutex.RUnlock()
	}
	for _, peer := range expiredPeers {
		peer.ExpireCurrentKeypairs()
	}
}
   312  
// NewTransport creates a Transport in the down state that logs to logger,
// exchanges packets with sourceSink and uses bind for network I/O.
//
// It initialises the peer map, rate limiter, index table, pools and queues,
// then starts one encryption, one decryption and one handshake worker per
// CPU plus a single goroutine reading from the source/sink. The returned
// transport is not yet up; call Up to open the bind.
func NewTransport(logger *slog.Logger, sourceSink SourceSink, bind conn.Bind) *Transport {
	t := new(Transport)
	t.state.state.Store(uint32(transportStateDown))
	t.closed = make(chan struct{})
	t.logger = logger
	t.net.bind = bind
	t.sourceSink = sourceSink
	t.peers.keyMap = make(map[types.NoisePublicKey]*Peer)
	t.rate.limiter.Init()
	t.indexTable.Init()

	t.PopulatePools()

	// create queues

	t.queue.handshake = newHandshakeQueue()
	t.queue.encryption = newOutboundQueue()
	t.queue.decryption = newInboundQueue()

	// start workers

	cpus := runtime.NumCPU()
	// NOTE(review): nothing in this function has added to state.stopping
	// yet, so this Wait presumably returns immediately — confirm.
	t.state.stopping.Wait()
	t.queue.encryption.wg.Add(cpus) // One for each RoutineHandshake
	for i := 0; i < cpus; i++ {
		// Worker ids are 1-based.
		go t.RoutineEncryption(i + 1)
		go t.RoutineDecryption(i + 1)
		go t.RoutineHandshake(i + 1)
	}

	// NOTE(review): both counts below are presumably released by
	// RoutineReadFromSourceSink when it exits — confirm.
	t.state.stopping.Add(1)
	t.queue.encryption.wg.Add(1)
	go t.RoutineReadFromSourceSink()

	return t
}
   349  
   350  // BatchSize returns the BatchSize for the transport as a whole which is the max of
   351  // the bind batch size and the sink batch size. The batch size reported by transport
   352  // is the size used to construct memory pools, and is the allowed batch size for
   353  // the lifetime of the transport.
   354  func (transport *Transport) BatchSize() int {
   355  	size := transport.net.bind.BatchSize()
   356  	dSize := transport.sourceSink.BatchSize()
   357  	if size < dSize {
   358  		size = dSize
   359  	}
   360  	return size
   361  }
   362  
   363  func (transport *Transport) Peers() []types.NoisePublicKey {
   364  	transport.peers.RLock()
   365  	defer transport.peers.RUnlock()
   366  
   367  	keys := make([]types.NoisePublicKey, 0, len(transport.peers.keyMap))
   368  	for key := range transport.peers.keyMap {
   369  		keys = append(keys, key)
   370  	}
   371  	return keys
   372  }
   373  
   374  func (transport *Transport) LookupPeer(pk types.NoisePublicKey) *Peer {
   375  	transport.peers.RLock()
   376  	defer transport.peers.RUnlock()
   377  
   378  	return transport.peers.keyMap[pk]
   379  }
   380  
   381  func (transport *Transport) RemovePeer(pk types.NoisePublicKey) {
   382  	transport.peers.Lock()
   383  	defer transport.peers.Unlock()
   384  	// stop peer and remove from routing
   385  
   386  	peer, ok := transport.peers.keyMap[pk]
   387  	if ok {
   388  		removePeerLocked(transport, peer, pk)
   389  	}
   390  }
   391  
   392  func (transport *Transport) RemoveAllPeers() {
   393  	transport.peers.Lock()
   394  	defer transport.peers.Unlock()
   395  
   396  	for key, peer := range transport.peers.keyMap {
   397  		removePeerLocked(transport, peer, key)
   398  	}
   399  
   400  	transport.peers.keyMap = make(map[types.NoisePublicKey]*Peer)
   401  }
   402  
   403  func (transport *Transport) Close() error {
   404  	transport.state.Lock()
   405  	defer transport.state.Unlock()
   406  	if transport.isClosed() {
   407  		return nil
   408  	}
   409  	transport.state.state.Store(uint32(transportStateClosed))
   410  	transport.logger.Debug("Transport closing")
   411  
   412  	_ = transport.sourceSink.Close()
   413  	_ = transport.downLocked()
   414  
   415  	// Remove peers before closing queues,
   416  	// because peers assume that queues are active.
   417  	transport.RemoveAllPeers()
   418  
   419  	// We kept a reference to the encryption and decryption queues,
   420  	// in case we started any new peers that might write to them.
   421  	// No new peers are coming; we are done with these queues.
   422  	transport.queue.encryption.wg.Done()
   423  	transport.queue.decryption.wg.Done()
   424  	transport.queue.handshake.wg.Done()
   425  	transport.state.stopping.Wait()
   426  
   427  	if err := transport.rate.limiter.Close(); err != nil {
   428  		return fmt.Errorf("failed to close rate limiter: %w", err)
   429  	}
   430  
   431  	transport.logger.Debug("Transport closed")
   432  	close(transport.closed)
   433  
   434  	return nil
   435  }
   436  
// Wait returns the transport's closed channel, which Close closes as its
// final shutdown step. Receiving from the returned channel therefore blocks
// until the transport has been closed.
func (transport *Transport) Wait() chan struct{} {
	return transport.closed
}
   440  
   441  func (transport *Transport) SendKeepalivesToPeersWithCurrentKeypair() {
   442  	if !transport.isUp() {
   443  		return
   444  	}
   445  
   446  	transport.peers.RLock()
   447  	for _, peer := range transport.peers.keyMap {
   448  		peer.keypairs.RLock()
   449  		sendKeepalive := peer.keypairs.current != nil && !peer.keypairs.current.created.Add(RejectAfterTime).Before(time.Now())
   450  		peer.keypairs.RUnlock()
   451  		if sendKeepalive {
   452  			if err := peer.SendKeepalive(); err != nil {
   453  				transport.logger.Error("Failed to send keepalive",
   454  					slog.String("peer", peer.String()), slog.Any("error", err))
   455  			}
   456  		}
   457  	}
   458  	transport.peers.RUnlock()
   459  }
   460  
   461  // closeBindLocked closes the transport's net.bind.
   462  // The caller must hold the net mutex.
   463  func closeBindLocked(transport *Transport) error {
   464  	var err error
   465  	netc := &transport.net
   466  	if netc.bind != nil {
   467  		err = netc.bind.Close()
   468  	}
   469  	netc.stopping.Wait()
   470  	return err
   471  }
   472  
   473  func (transport *Transport) Bind() conn.Bind {
   474  	transport.net.Lock()
   475  	defer transport.net.Unlock()
   476  	return transport.net.bind
   477  }
   478  
   479  func (transport *Transport) BindUpdate() error {
   480  	transport.net.Lock()
   481  	defer transport.net.Unlock()
   482  
   483  	// close existing sockets
   484  	if err := closeBindLocked(transport); err != nil {
   485  		return err
   486  	}
   487  
   488  	// open new sockets
   489  	if !transport.isUp() {
   490  		return nil
   491  	}
   492  
   493  	// bind to new port
   494  	var err error
   495  	var recvFns []conn.ReceiveFunc
   496  	netc := &transport.net
   497  
   498  	recvFns, netc.port, err = netc.bind.Open(netc.port)
   499  	if err != nil {
   500  		netc.port = 0
   501  		return err
   502  	}
   503  
   504  	// start receiving routines
   505  	transport.net.stopping.Add(len(recvFns))
   506  	transport.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to transport.queue.decryption
   507  	transport.queue.handshake.wg.Add(len(recvFns))  // each RoutineReceiveIncoming goroutine writes to transport.queue.handshake
   508  	batchSize := netc.bind.BatchSize()
   509  	for _, fn := range recvFns {
   510  		go transport.RoutineReceiveIncoming(batchSize, fn)
   511  	}
   512  
   513  	transport.logger.Debug("UDP bind has been updated")
   514  	return nil
   515  }
   516  
   517  func (transport *Transport) BindClose() error {
   518  	transport.net.Lock()
   519  	err := closeBindLocked(transport)
   520  	transport.net.Unlock()
   521  	return err
   522  }
   523  
   524  func (transport *Transport) UpdatePort(port uint16) error {
   525  	transport.net.Lock()
   526  	transport.net.port = port
   527  	transport.net.Unlock()
   528  
   529  	if err := transport.BindUpdate(); err != nil {
   530  		return fmt.Errorf("failed to update port: %w", err)
   531  	}
   532  
   533  	return nil
   534  }