github.com/noisysockets/noisysockets@v0.21.2-0.20240515114641-7f467e651c90/internal/transport/send.go (about)

     1  // SPDX-License-Identifier: MPL-2.0
     2  /*
     3   * Copyright (C) 2024 The Noisy Sockets Authors.
     4   *
     5   * This Source Code Form is subject to the terms of the Mozilla Public
     6   * License, v. 2.0. If a copy of the MPL was not distributed with this
     7   * file, You can obtain one at http://mozilla.org/MPL/2.0/.
     8   *
     9   * Portions of this file are based on code originally from wireguard-go,
    10   *
    11   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
    12   *
    13   * Permission is hereby granted, free of charge, to any person obtaining a copy of
    14   * this software and associated documentation files (the "Software"), to deal in
    15   * the Software without restriction, including without limitation the rights to
    16   * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
    17   * of the Software, and to permit persons to whom the Software is furnished to do
    18   * so, subject to the following conditions:
    19   *
    20   * The above copyright notice and this permission notice shall be included in all
    21   * copies or substantial portions of the Software.
    22   *
    23   * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    24   * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    25   * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    26   * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    27   * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    28   * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    29   * SOFTWARE.
    30   */
    31  
    32  package transport
    33  
    34  import (
    35  	"bytes"
    36  	"encoding/binary"
    37  	"errors"
    38  	"log/slog"
    39  	"os"
    40  	"sync"
    41  	"time"
    42  
    43  	"github.com/noisysockets/noisysockets/internal/conn"
    44  	"github.com/noisysockets/noisysockets/types"
    45  	"golang.org/x/crypto/chacha20poly1305"
    46  )
    47  
// DefaultMTU is the default maximum transmission unit, in bytes, for the
// tunneled network interface.
const DefaultMTU = 1420
    49  
    50  /* Outbound flow
    51   *
    52   * 1. Source queue
    53   * 2. Routing (sequential)
    54   * 3. Nonce assignment (sequential)
    55   * 4. Encryption (parallel)
    56   * 5. Transmission (sequential)
    57   *
    58   * The functions in this file occur (roughly) in the order in
    59   * which the packets are processed.
    60   *
    61   * Locking, Producers and Consumers
    62   *
    63   * The order of packets (per peer) must be maintained,
    64   * but encryption of packets happen out-of-order:
    65   *
    66   * The sequential consumers will attempt to take the lock,
    67   * workers release lock when they have completed work (encryption) on the packet.
    68   *
    69   * If the element is inserted into the "encryption queue",
    70   * the content is preceded by enough "junk" to contain the transport header
    71   * (to allow the construction of transport messages in-place)
    72   */
    73  
// QueueOutboundElement is a single outbound packet travelling through the
// send pipeline (routing -> nonce assignment -> encryption -> transmission).
type QueueOutboundElement struct {
	buffer  *[MaxMessageSize]byte // slice holding the packet data
	packet  []byte                // slice of "buffer" (always!)
	nonce   uint64                // nonce for encryption
	keypair *Keypair              // keypair for encryption
	peer    *Peer                 // related peer
}
    81  
// QueueOutboundElementsContainer groups a batch of outbound elements.
// The embedded mutex doubles as a completion signal: SendStagedPackets locks
// it before queueing the batch, an encryption worker unlocks it once every
// element is sealed, and the sequential sender blocks on Lock() to wait for
// encryption to finish while preserving per-peer ordering.
type QueueOutboundElementsContainer struct {
	sync.Mutex
	elems []*QueueOutboundElement
}
    86  
// NewOutboundElement returns an outbound element from the shared pool with a
// freshly leased message buffer attached and its nonce reset to zero.
func (transport *Transport) NewOutboundElement() *QueueOutboundElement {
	elem := transport.GetOutboundElement()
	elem.buffer = transport.GetMessageBuffer()
	elem.nonce = 0
	// keypair and peer were cleared (if necessary) by clearPointers.
	return elem
}
    94  
// clearPointers clears elem fields that contain pointers.
// This makes the garbage collector's life easier and
// avoids accidentally keeping other objects around unnecessarily.
// It also reduces the possible collateral damage from use-after-free bugs.
// Note: elem.nonce is deliberately left untouched; it is reset when the
// element is re-leased via NewOutboundElement.
func (elem *QueueOutboundElement) clearPointers() {
	elem.buffer = nil
	elem.packet = nil
	elem.keypair = nil
	elem.peer = nil
}
   105  
   106  /* Queues a keepalive if no packets are queued for peer
   107   */
   108  func (peer *Peer) SendKeepalive() error {
   109  	if len(peer.queue.staged) == 0 && peer.isRunning.Load() {
   110  		elem := peer.transport.NewOutboundElement()
   111  		elemsContainer := peer.transport.GetOutboundElementsContainer()
   112  		elemsContainer.elems = append(elemsContainer.elems, elem)
   113  		select {
   114  		case peer.queue.staged <- elemsContainer:
   115  			peer.transport.logger.Debug("Sending keepalive packet", slog.String("peer", peer.String()))
   116  		default:
   117  			peer.transport.PutMessageBuffer(elem.buffer)
   118  			peer.transport.PutOutboundElement(elem)
   119  			peer.transport.PutOutboundElementsContainer(elemsContainer)
   120  		}
   121  	}
   122  	return peer.SendStagedPackets()
   123  }
   124  
// SendHandshakeInitiation builds, MACs and transmits a handshake initiation
// message to the peer. It is rate limited: at most one initiation per
// RekeyTimeout interval. When isRetry is false the retry counter is reset
// first. A missing endpoint is not an error; the request is ignored.
func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
	logger := peer.transport.logger.With(slog.String("peer", peer.String()))

	peer.endpoint.Lock()
	endpoint := peer.endpoint.val
	peer.endpoint.Unlock()

	// If we don't have an endpoint, ignore the request.
	if endpoint == nil {
		return nil
	}

	if !isRetry {
		peer.timers.handshakeAttempts.Store(0)
	}

	// Double-checked rate limit: take the cheap read lock first, then
	// re-verify under the write lock before committing to a send.
	peer.handshake.mutex.RLock()
	if time.Since(peer.handshake.lastSentHandshake) < RekeyTimeout {
		peer.handshake.mutex.RUnlock()
		return nil
	}
	peer.handshake.mutex.RUnlock()

	peer.handshake.mutex.Lock()
	if time.Since(peer.handshake.lastSentHandshake) < RekeyTimeout {
		peer.handshake.mutex.Unlock()
		return nil
	}
	peer.handshake.lastSentHandshake = time.Now()
	peer.handshake.mutex.Unlock()

	logger.Debug("Sending handshake initiation")

	msg, err := peer.transport.CreateMessageInitiation(peer)
	if err != nil {
		logger.Error("Failed to create initiation message", slog.Any("error", err))
		return err
	}

	// Serialize the message into a stack buffer, then stamp the cookie
	// MACs over the serialized bytes in place.
	var buf [MessageInitiationSize]byte
	writer := bytes.NewBuffer(buf[:0])
	if err := binary.Write(writer, binary.LittleEndian, msg); err != nil {
		logger.Error("Failed to write initiation message", slog.Any("error", err))
		return err
	}

	packet := writer.Bytes()
	peer.cookieGenerator.AddMacs(packet)

	peer.timersAnyAuthenticatedPacketTraversal()
	peer.timersAnyAuthenticatedPacketSent()

	err = peer.SendBuffers([][]byte{packet})
	if err != nil {
		logger.Error("Failed to send handshake initiation", slog.Any("error", err))
	}
	// Start the retransmission timer even when the send failed, so the
	// handshake is retried.
	peer.timersHandshakeInitiated()

	return err
}
   185  
// SendHandshakeResponse builds, MACs and transmits a handshake response
// message, deriving the new symmetric session keys (as responder) before the
// packet is sent.
func (peer *Peer) SendHandshakeResponse() error {
	logger := peer.transport.logger.With(slog.String("peer", peer.String()))

	peer.handshake.mutex.Lock()
	peer.handshake.lastSentHandshake = time.Now()
	peer.handshake.mutex.Unlock()

	logger.Debug("Sending handshake response")

	response, err := peer.transport.CreateMessageResponse(peer)
	if err != nil {
		logger.Error("Failed to create handshake response message", slog.Any("error", err))
		return err
	}

	// Serialize the response into a stack buffer, then stamp the cookie
	// MACs over the serialized bytes in place.
	var buf [MessageResponseSize]byte
	writer := bytes.NewBuffer(buf[:0])
	if err := binary.Write(writer, binary.LittleEndian, response); err != nil {
		logger.Error("Failed to write handshake response message", slog.Any("error", err))
		return err
	}

	packet := writer.Bytes()
	peer.cookieGenerator.AddMacs(packet)

	// As the responder we can derive the session keypair immediately.
	err = peer.BeginSymmetricSession()
	if err != nil {
		logger.Error("Failed to derive keypair", slog.Any("error", err))
		return err
	}

	peer.timersSessionDerived()
	peer.timersAnyAuthenticatedPacketTraversal()
	peer.timersAnyAuthenticatedPacketSent()

	// TODO: allocation could be avoided
	err = peer.SendBuffers([][]byte{packet})
	if err != nil {
		logger.Error("Failed to send handshake response", slog.Any("error", err))
	}
	return err
}
   228  
// SendHandshakeCookie sends a cookie reply in response to a handshake
// message that was denied, allowing the initiator to retry with a valid
// cookie MAC.
func (transport *Transport) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) error {
	logger := transport.logger.With(slog.String("source", initiatingElem.endpoint.DstToString()))

	logger.Debug("Sending cookie response for denied handshake message")

	// The sender index occupies bytes 4..8 of the initiating message.
	sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8])
	reply, err := transport.cookieChecker.CreateReply(initiatingElem.packet, sender, initiatingElem.endpoint.DstToBytes())
	if err != nil {
		logger.Error("Failed to create cookie reply", slog.Any("error", err))
		return err
	}

	var buf [MessageCookieReplySize]byte
	writer := bytes.NewBuffer(buf[:0])
	if err := binary.Write(writer, binary.LittleEndian, reply); err != nil {
		logger.Error("Failed to write cookie reply", slog.Any("error", err))
		return err
	}

	// TODO: allocation could be avoided
	return transport.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint)
}
   251  
   252  func (peer *Peer) keepKeyFreshSending() error {
   253  	keypair := peer.keypairs.Current()
   254  	if keypair == nil {
   255  		return nil
   256  	}
   257  
   258  	nonce := keypair.sendNonce.Load()
   259  	if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) {
   260  		return peer.SendHandshakeInitiation(false)
   261  	}
   262  
   263  	return nil
   264  }
   265  
   266  func (transport *Transport) RoutineReadFromSourceSink() {
   267  	defer func() {
   268  		transport.logger.Debug("Routine: Source reader - stopped")
   269  		transport.state.stopping.Done()
   270  		transport.queue.encryption.wg.Done()
   271  	}()
   272  
   273  	transport.logger.Debug("Routine: Source reader - started")
   274  
   275  	var (
   276  		batchSize   = transport.BatchSize()
   277  		readErr     error
   278  		elems       = make([]*QueueOutboundElement, batchSize)
   279  		bufs        = make([][]byte, batchSize)
   280  		peers       = make([]types.NoisePublicKey, batchSize)
   281  		elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize)
   282  		count       int
   283  		sizes       = make([]int, batchSize)
   284  		offset      = MessageTransportHeaderSize
   285  	)
   286  
   287  	for i := range elems {
   288  		elems[i] = transport.NewOutboundElement()
   289  		bufs[i] = elems[i].buffer[:]
   290  	}
   291  
   292  	defer func() {
   293  		for _, elem := range elems {
   294  			if elem != nil {
   295  				transport.PutMessageBuffer(elem.buffer)
   296  				transport.PutOutboundElement(elem)
   297  			}
   298  		}
   299  	}()
   300  
   301  	for {
   302  		// read packets
   303  		count, readErr = transport.sourceSink.Read(bufs, sizes, peers, offset)
   304  		for i := 0; i < count; i++ {
   305  			if sizes[i] < 1 {
   306  				continue
   307  			}
   308  
   309  			elem := elems[i]
   310  			elem.packet = bufs[i][offset : offset+sizes[i]]
   311  
   312  			transport.peers.RLock()
   313  			peer := transport.peers.keyMap[peers[i]]
   314  			transport.peers.RUnlock()
   315  			if peer == nil {
   316  				continue
   317  			}
   318  
   319  			elemsForPeer, ok := elemsByPeer[peer]
   320  			if !ok {
   321  				elemsForPeer = transport.GetOutboundElementsContainer()
   322  				elemsByPeer[peer] = elemsForPeer
   323  			}
   324  			elemsForPeer.elems = append(elemsForPeer.elems, elem)
   325  			elems[i] = transport.NewOutboundElement()
   326  			bufs[i] = elems[i].buffer[:]
   327  		}
   328  
   329  		for peer, elemsForPeer := range elemsByPeer {
   330  			if peer.isRunning.Load() {
   331  				peer.StagePackets(elemsForPeer)
   332  				if err := peer.SendStagedPackets(); err != nil {
   333  					transport.logger.Warn("Failed to send staged packets",
   334  						slog.String("peer", peer.String()), slog.Any("error", err))
   335  					continue
   336  				}
   337  			} else {
   338  				for _, elem := range elemsForPeer.elems {
   339  					transport.PutMessageBuffer(elem.buffer)
   340  					transport.PutOutboundElement(elem)
   341  				}
   342  				transport.PutOutboundElementsContainer(elemsForPeer)
   343  			}
   344  			delete(elemsByPeer, peer)
   345  		}
   346  
   347  		if readErr != nil {
   348  			if !transport.isClosed() {
   349  				if !errors.Is(readErr, os.ErrClosed) {
   350  					transport.logger.Error("Failed to read packet from source sink",
   351  						slog.Any("error", readErr))
   352  				}
   353  				go transport.Close()
   354  			}
   355  			return
   356  		}
   357  	}
   358  }
   359  
// StagePackets places a batch on the peer's staging queue without ever
// blocking: if the queue is full, the oldest staged batch is evicted
// (recycled to the pools) to make room, and the send is retried.
func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) {
	for {
		select {
		case peer.queue.staged <- elems:
			return
		default:
		}
		// Queue full: drop the oldest batch (if one is still there) and
		// loop to retry the send.
		select {
		case tooOld := <-peer.queue.staged:
			for _, elem := range tooOld.elems {
				peer.transport.PutMessageBuffer(elem.buffer)
				peer.transport.PutOutboundElement(elem)
			}
			peer.transport.PutOutboundElementsContainer(tooOld)
		default:
		}
	}
}
   378  
// SendStagedPackets drains the peer's staging queue into the parallel
// encryption queue and the per-peer sequential send queue. Each element is
// assigned the next send nonce from the current keypair. If no usable
// keypair exists (missing, nonce space exhausted, or expired), a handshake
// initiation is sent instead and the packets remain staged. Elements whose
// assigned nonce overflows the limit mid-batch are moved to a separate
// container and re-staged for after the rekey.
func (peer *Peer) SendStagedPackets() error {
top:
	if len(peer.queue.staged) == 0 || !peer.transport.isUp() {
		return nil
	}

	keypair := peer.keypairs.Current()
	if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime {
		return peer.SendHandshakeInitiation(false)
	}

	for {
		var elemsContainerOOO *QueueOutboundElementsContainer
		select {
		case elemsContainer := <-peer.queue.staged:
			i := 0
			for _, elem := range elemsContainer.elems {
				elem.peer = peer
				// Reserve the next nonce (Add returns the incremented value).
				elem.nonce = keypair.sendNonce.Add(1) - 1
				if elem.nonce >= RejectAfterMessages {
					// Nonce space exhausted: pin the counter and divert the
					// element to the out-of-order container for re-staging.
					keypair.sendNonce.Store(RejectAfterMessages)
					if elemsContainerOOO == nil {
						elemsContainerOOO = peer.transport.GetOutboundElementsContainer()
					}
					elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem)
					continue
				} else {
					// Compact the usable elements to the front of the slice.
					elemsContainer.elems[i] = elem
					i++
				}

				elem.keypair = keypair
			}
			// Lock the container before queueing: the encryption worker
			// unlocks it when sealing is done, which is what the sequential
			// sender waits on.
			elemsContainer.Lock()
			elemsContainer.elems = elemsContainer.elems[:i]

			if elemsContainerOOO != nil {
				peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans
			}

			if len(elemsContainer.elems) == 0 {
				peer.transport.PutOutboundElementsContainer(elemsContainer)
				goto top
			}

			// add to parallel and sequential queue
			if peer.isRunning.Load() {
				peer.queue.outbound.c <- elemsContainer
				peer.transport.queue.encryption.c <- elemsContainer
			} else {
				for _, elem := range elemsContainer.elems {
					peer.transport.PutMessageBuffer(elem.buffer)
					peer.transport.PutOutboundElement(elem)
				}
				peer.transport.PutOutboundElementsContainer(elemsContainer)
			}

			if elemsContainerOOO != nil {
				// The re-staged elements need a fresh keypair; restart so
				// the keypair checks run again.
				goto top
			}
		default:
			return nil
		}
	}
}
   444  
   445  func (peer *Peer) FlushStagedPackets() {
   446  	for {
   447  		select {
   448  		case elemsContainer := <-peer.queue.staged:
   449  			for _, elem := range elemsContainer.elems {
   450  				peer.transport.PutMessageBuffer(elem.buffer)
   451  				peer.transport.PutOutboundElement(elem)
   452  			}
   453  			peer.transport.PutOutboundElementsContainer(elemsContainer)
   454  		default:
   455  			return
   456  		}
   457  	}
   458  }
   459  
   460  func calculatePaddingSize(packetSize, mtu int) int {
   461  	lastUnit := packetSize
   462  	if mtu == 0 {
   463  		return ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)) - lastUnit
   464  	}
   465  	if lastUnit > mtu {
   466  		lastUnit %= mtu
   467  	}
   468  	paddedSize := ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1))
   469  	if paddedSize > mtu {
   470  		paddedSize = mtu
   471  	}
   472  	return paddedSize - lastUnit
   473  }
   474  
// RoutineEncryption encrypts queued outbound batches: it writes the
// transport header (type, receiver index, counter) into the space reserved
// at the front of each buffer, pads the plaintext to a multiple of
// PaddingMultiple bytes, and seals it in-place with the batch's keypair.
// Unlocking the container's mutex marks the batch as ready for the
// sequential sender.
//
// Obs. One instance per core.
func (transport *Transport) RoutineEncryption(id int) {
	var paddingZeros [PaddingMultiple]byte
	var nonce [chacha20poly1305.NonceSize]byte

	logger := transport.logger.With(slog.Int("id", id))

	defer logger.Debug("Routine: encryption worker - stopped")
	logger.Debug("Routine: encryption worker - started")

	for elemsContainer := range transport.queue.encryption.c {
		for _, elem := range elemsContainer.elems {
			// populate header fields
			header := elem.buffer[:MessageTransportHeaderSize]

			fieldType := header[0:4]
			fieldReceiver := header[4:8]
			fieldNonce := header[8:16]

			binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
			binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
			binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)

			// pad content to multiple of 16 bytes
			paddingSize := calculatePaddingSize(len(elem.packet), transport.sourceSink.MTU())
			elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)

			// encrypt content and release to consumer

			// AEAD nonce: four zero bytes followed by the 64-bit
			// little-endian counter.
			binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
			// Seal appends the ciphertext directly after the header,
			// writing into elem.buffer in place.
			elem.packet = elem.keypair.send.Seal(
				header,
				nonce[:],
				elem.packet,
				nil,
			)
		}
		// Signal the sequential sender that this batch is fully encrypted.
		elemsContainer.Unlock()
	}
}
   519  
// RoutineSequentialSender transmits encrypted batches for a single peer in
// order. It blocks on each container's mutex — released by the encryption
// workers once the batch is sealed — so packets leave in the order they were
// queued. One instance runs per peer; a nil container on the queue is the
// shutdown sentinel.
func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
	logger := peer.transport.logger.With(slog.String("peer", peer.String()))

	transport := peer.transport
	defer func() {
		defer logger.Debug("Routine: sequential sender - stopped")
		peer.stopping.Done()
	}()
	logger.Debug("Routine: sequential sender - started")

	bufs := make([][]byte, 0, maxBatchSize)

	for elemsContainer := range peer.queue.outbound.c {
		bufs = bufs[:0]
		if elemsContainer == nil {
			return
		}
		if !peer.isRunning.Load() {
			// peer has been stopped; return re-usable elems to the shared pool.
			// This is an optimization only. It is possible for the peer to be stopped
			// immediately after this check, in which case, elem will get processed.
			// The timers and SendBuffers code are resilient to a few stragglers.
			// TODO: rework peer shutdown order to ensure
			// that we never accidentally keep timers alive longer than necessary.
			// Lock() here waits for the encryption workers to finish with
			// this batch before its buffers are reclaimed.
			// NOTE(review): unlike the success path below, the container
			// itself is not returned to the pool and is left for GC —
			// confirm against upstream whether that is intentional.
			elemsContainer.Lock()
			for _, elem := range elemsContainer.elems {
				transport.PutMessageBuffer(elem.buffer)
				transport.PutOutboundElement(elem)
			}
			continue
		}
		dataSent := false
		// Wait until encryption has sealed every element in the batch.
		elemsContainer.Lock()
		for _, elem := range elemsContainer.elems {
			// Anything other than a bare keepalive counts as data traffic.
			if len(elem.packet) != MessageKeepaliveSize {
				dataSent = true
			}
			bufs = append(bufs, elem.packet)
		}

		peer.timersAnyAuthenticatedPacketTraversal()
		peer.timersAnyAuthenticatedPacketSent()

		err := peer.SendBuffers(bufs)
		if dataSent {
			peer.timersDataSent()
		}
		for _, elem := range elemsContainer.elems {
			transport.PutMessageBuffer(elem.buffer)
			transport.PutOutboundElement(elem)
		}
		transport.PutOutboundElementsContainer(elemsContainer)
		if err != nil {
			// If UDP GSO was disabled mid-flight, substitute the wrapped
			// retry error before deciding how to report the failure.
			var errGSO conn.ErrUDPGSODisabled
			if errors.As(err, &errGSO) {
				logger.Warn("Failed to send data packets, retrying", slog.Any("error", err))
				err = errGSO.RetryErr
			}
		}
		if err != nil {
			logger.Error("Failed to send data packets", slog.Any("error", err))
			continue
		}

		if err := peer.keepKeyFreshSending(); err != nil {
			logger.Error("Failed to keep key fresh", slog.Any("error", err))
		}
	}
}