github.com/noisysockets/noisysockets@v0.21.2-0.20240515114641-7f467e651c90/internal/transport/receive.go (about)

     1  // SPDX-License-Identifier: MPL-2.0
     2  /*
     3   * Copyright (C) 2024 The Noisy Sockets Authors.
     4   *
     5   * This Source Code Form is subject to the terms of the Mozilla Public
     6   * License, v. 2.0. If a copy of the MPL was not distributed with this
     7   * file, You can obtain one at http://mozilla.org/MPL/2.0/.
     8   *
     9   * Portions of this file are based on code originally from wireguard-go,
    10   *
    11   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
    12   *
    13   * Permission is hereby granted, free of charge, to any person obtaining a copy of
    14   * this software and associated documentation files (the "Software"), to deal in
    15   * the Software without restriction, including without limitation the rights to
    16   * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
    17   * of the Software, and to permit persons to whom the Software is furnished to do
    18   * so, subject to the following conditions:
    19   *
    20   * The above copyright notice and this permission notice shall be included in all
    21   * copies or substantial portions of the Software.
    22   *
    23   * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    24   * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    25   * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    26   * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    27   * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    28   * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    29   * SOFTWARE.
    30   */
    31  
    32  package transport
    33  
    34  import (
    35  	"bytes"
    36  	"encoding/binary"
    37  	"errors"
    38  	"log/slog"
    39  	"net"
    40  	"sync"
    41  	"time"
    42  
    43  	"github.com/noisysockets/noisysockets/internal/conn"
    44  	"github.com/noisysockets/noisysockets/types"
    45  	"golang.org/x/crypto/chacha20poly1305"
    46  )
    47  
    48  type QueueHandshakeElement struct {
    49  	msgType  uint32
    50  	packet   []byte
    51  	endpoint conn.Endpoint
    52  	buffer   *[MaxMessageSize]byte
    53  }
    54  
    55  type QueueInboundElement struct {
    56  	buffer   *[MaxMessageSize]byte
    57  	packet   []byte
    58  	counter  uint64
    59  	keypair  *Keypair
    60  	endpoint conn.Endpoint
    61  }
    62  
    63  type QueueInboundElementsContainer struct {
    64  	sync.Mutex
    65  	elems []*QueueInboundElement
    66  }
    67  
    68  // clearPointers clears elem fields that contain pointers.
    69  // This makes the garbage collector's life easier and
    70  // avoids accidentally keeping other objects around unnecessarily.
    71  // It also reduces the possible collateral damage from use-after-free bugs.
    72  func (elem *QueueInboundElement) clearPointers() {
    73  	elem.buffer = nil
    74  	elem.packet = nil
    75  	elem.keypair = nil
    76  	elem.endpoint = nil
    77  }
    78  
    79  /* Called when a new authenticated message has been received
    80   *
    81   * NOTE: Not thread safe, but called by sequential receiver!
    82   */
    83  func (peer *Peer) keepKeyFreshReceiving() error {
    84  	if peer.timers.sentLastMinuteHandshake.Load() {
    85  		return nil
    86  	}
    87  
    88  	keypair := peer.keypairs.Current()
    89  	if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
    90  		peer.timers.sentLastMinuteHandshake.Store(true)
    91  		if err := peer.SendHandshakeInitiation(false); err != nil {
    92  			return err
    93  		}
    94  	}
    95  
    96  	return nil
    97  }
    98  
    99  /* Receives incoming datagrams for the transport
   100   *
   101   * Every time the bind is updated a new routine is started for
   102   * IPv4 and IPv6 (separately)
   103   */
   104  func (transport *Transport) RoutineReceiveIncoming(maxBatchSize int, recv conn.ReceiveFunc) {
   105  	recvName := recv.PrettyName()
   106  	defer func() {
   107  		transport.logger.Debug("Routine: receive incoming - stopped", slog.String("recvName", recvName))
   108  		transport.queue.decryption.wg.Done()
   109  		transport.queue.handshake.wg.Done()
   110  		transport.net.stopping.Done()
   111  	}()
   112  
   113  	transport.logger.Debug("Routine: receive incoming - started", slog.String("recvName", recvName))
   114  
   115  	// receive datagrams until conn is closed
   116  
   117  	var (
   118  		bufsArrs    = make([]*[MaxMessageSize]byte, maxBatchSize)
   119  		bufs        = make([][]byte, maxBatchSize)
   120  		err         error
   121  		sizes       = make([]int, maxBatchSize)
   122  		count       int
   123  		endpoints   = make([]conn.Endpoint, maxBatchSize)
   124  		deathSpiral int
   125  		elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize)
   126  	)
   127  
   128  	for i := range bufsArrs {
   129  		bufsArrs[i] = transport.GetMessageBuffer()
   130  		bufs[i] = bufsArrs[i][:]
   131  	}
   132  
   133  	defer func() {
   134  		for i := 0; i < maxBatchSize; i++ {
   135  			if bufsArrs[i] != nil {
   136  				transport.PutMessageBuffer(bufsArrs[i])
   137  			}
   138  		}
   139  	}()
   140  
   141  	for {
   142  		count, err = recv(bufs, sizes, endpoints)
   143  		if err != nil {
   144  			if errors.Is(err, net.ErrClosed) {
   145  				return
   146  			}
   147  			transport.logger.Warn("Failed to receive packet",
   148  				slog.String("recvName", recvName),
   149  				slog.Any("error", err))
   150  			if deathSpiral < 10 {
   151  				deathSpiral++
   152  				time.Sleep(time.Second / 3)
   153  				continue
   154  			}
   155  			return
   156  		}
   157  		deathSpiral = 0
   158  
   159  		// handle each packet in the batch
   160  		for i, size := range sizes[:count] {
   161  			if size < MinMessageSize {
   162  				continue
   163  			}
   164  
   165  			// check size of packet
   166  
   167  			packet := bufsArrs[i][:size]
   168  			msgType := binary.LittleEndian.Uint32(packet[:4])
   169  
   170  			switch msgType {
   171  
   172  			// check if transport
   173  
   174  			case MessageTransportType:
   175  
   176  				// check size
   177  
   178  				if len(packet) < MessageTransportSize {
   179  					continue
   180  				}
   181  
   182  				// lookup key pair
   183  
   184  				receiver := binary.LittleEndian.Uint32(
   185  					packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
   186  				)
   187  				value := transport.indexTable.Lookup(receiver)
   188  				keypair := value.keypair
   189  				if keypair == nil {
   190  					continue
   191  				}
   192  
   193  				// check keypair expiry
   194  
   195  				if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
   196  					continue
   197  				}
   198  
   199  				// create work element
   200  				peer := value.peer
   201  				elem := transport.GetInboundElement()
   202  				elem.packet = packet
   203  				elem.buffer = bufsArrs[i]
   204  				elem.keypair = keypair
   205  				elem.endpoint = endpoints[i]
   206  				elem.counter = 0
   207  
   208  				elemsForPeer, ok := elemsByPeer[peer]
   209  				if !ok {
   210  					elemsForPeer = transport.GetInboundElementsContainer()
   211  					elemsForPeer.Lock()
   212  					elemsByPeer[peer] = elemsForPeer
   213  				}
   214  				elemsForPeer.elems = append(elemsForPeer.elems, elem)
   215  				bufsArrs[i] = transport.GetMessageBuffer()
   216  				bufs[i] = bufsArrs[i][:]
   217  				continue
   218  
   219  			// otherwise it is a fixed size & handshake related packet
   220  
   221  			case MessageInitiationType:
   222  				if len(packet) != MessageInitiationSize {
   223  					continue
   224  				}
   225  
   226  			case MessageResponseType:
   227  				if len(packet) != MessageResponseSize {
   228  					continue
   229  				}
   230  
   231  			case MessageCookieReplyType:
   232  				if len(packet) != MessageCookieReplySize {
   233  					continue
   234  				}
   235  
   236  			default:
   237  				transport.logger.Warn("Received message with unknown type",
   238  					slog.Int("type", int(msgType)))
   239  				continue
   240  			}
   241  
   242  			select {
   243  			case transport.queue.handshake.c <- QueueHandshakeElement{
   244  				msgType:  msgType,
   245  				buffer:   bufsArrs[i],
   246  				packet:   packet,
   247  				endpoint: endpoints[i],
   248  			}:
   249  				bufsArrs[i] = transport.GetMessageBuffer()
   250  				bufs[i] = bufsArrs[i][:]
   251  			default:
   252  			}
   253  		}
   254  		for peer, elemsContainer := range elemsByPeer {
   255  			if peer.isRunning.Load() {
   256  				peer.queue.inbound.c <- elemsContainer
   257  				transport.queue.decryption.c <- elemsContainer
   258  			} else {
   259  				for _, elem := range elemsContainer.elems {
   260  					transport.PutMessageBuffer(elem.buffer)
   261  					transport.PutInboundElement(elem)
   262  				}
   263  				transport.PutInboundElementsContainer(elemsContainer)
   264  			}
   265  			delete(elemsByPeer, peer)
   266  		}
   267  	}
   268  }
   269  
   270  func (transport *Transport) RoutineDecryption(id int) {
   271  	var nonce [chacha20poly1305.NonceSize]byte
   272  
   273  	defer transport.logger.Debug("Routine: decryption worker - stopped", slog.Int("id", id))
   274  	transport.logger.Debug("Routine: decryption worker - started", slog.Int("id", id))
   275  
   276  	for elemsContainer := range transport.queue.decryption.c {
   277  		for _, elem := range elemsContainer.elems {
   278  			// split message into fields
   279  			counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
   280  			content := elem.packet[MessageTransportOffsetContent:]
   281  
   282  			// decrypt and release to consumer
   283  			var err error
   284  			elem.counter = binary.LittleEndian.Uint64(counter)
   285  			// copy counter to nonce
   286  			binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
   287  			elem.packet, err = elem.keypair.receive.Open(
   288  				content[:0],
   289  				nonce[:],
   290  				content,
   291  				nil,
   292  			)
   293  			if err != nil {
   294  				elem.packet = nil
   295  			}
   296  		}
   297  		elemsContainer.Unlock()
   298  	}
   299  }
   300  
   301  // Handles incoming packets related to handshake.
   302  func (transport *Transport) RoutineHandshake(id int) {
   303  	logger := transport.logger.With(slog.Int("id", id))
   304  
   305  	defer func() {
   306  		logger.Debug("Routine: handshake worker - stopped")
   307  		transport.queue.encryption.wg.Done()
   308  	}()
   309  	logger.Debug("Routine: handshake worker - started")
   310  
   311  	for elem := range transport.queue.handshake.c {
   312  		logger := logger.With(slog.String("from", elem.endpoint.DstToString()))
   313  
   314  		// handle cookie fields and ratelimiting
   315  
   316  		switch elem.msgType {
   317  
   318  		case MessageCookieReplyType:
   319  
   320  			// unmarshal packet
   321  
   322  			var reply MessageCookieReply
   323  			reader := bytes.NewReader(elem.packet)
   324  			err := binary.Read(reader, binary.LittleEndian, &reply)
   325  			if err != nil {
   326  				logger.Warn("Failed to decode cookie reply", slog.Any("error", err))
   327  				goto skip
   328  			}
   329  
   330  			// lookup peer from index
   331  
   332  			entry := transport.indexTable.Lookup(reply.Receiver)
   333  
   334  			if entry.peer == nil {
   335  				goto skip
   336  			}
   337  
   338  			// consume reply
   339  
   340  			if peer := entry.peer; peer.isRunning.Load() {
   341  				logger.Debug("Receiving cookie response")
   342  				if !peer.cookieGenerator.ConsumeReply(&reply) {
   343  					logger.Warn("Could not decrypt invalid cookie response")
   344  				}
   345  			}
   346  
   347  			goto skip
   348  
   349  		case MessageInitiationType, MessageResponseType:
   350  
   351  			// check mac fields and maybe ratelimit
   352  
   353  			if !transport.cookieChecker.CheckMAC1(elem.packet) {
   354  				logger.Warn("Received packet with invalid mac1")
   355  				goto skip
   356  			}
   357  
   358  			// endpoints destination address is the source of the datagram
   359  
   360  			if transport.IsUnderLoad() {
   361  
   362  				// verify MAC2 field
   363  
   364  				if !transport.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) {
   365  					if err := transport.SendHandshakeCookie(&elem); err != nil {
   366  						logger.Warn("Failed to send handshake cookie", slog.Any("error", err))
   367  					}
   368  					goto skip
   369  				}
   370  
   371  				// check ratelimiter
   372  
   373  				if !transport.rate.limiter.Allow(elem.endpoint.DstIP()) {
   374  					goto skip
   375  				}
   376  			}
   377  
   378  		default:
   379  			logger.Warn("Invalid packet ended up in the handshake queue")
   380  			goto skip
   381  		}
   382  
   383  		// handle handshake initiation/response content
   384  
   385  		switch elem.msgType {
   386  		case MessageInitiationType:
   387  
   388  			// unmarshal
   389  
   390  			var msg MessageInitiation
   391  			reader := bytes.NewReader(elem.packet)
   392  			err := binary.Read(reader, binary.LittleEndian, &msg)
   393  			if err != nil {
   394  				logger.Warn("Failed to decode initiation message", slog.Any("error", err))
   395  				goto skip
   396  			}
   397  
   398  			// consume initiation
   399  
   400  			peer := transport.ConsumeMessageInitiation(&msg)
   401  			if peer == nil {
   402  				logger.Warn("Received invalid initiation message")
   403  				goto skip
   404  			}
   405  
   406  			// update timers
   407  
   408  			peer.timersAnyAuthenticatedPacketTraversal()
   409  			peer.timersAnyAuthenticatedPacketReceived()
   410  
   411  			// update endpoint
   412  			peer.SetEndpoint(elem.endpoint)
   413  
   414  			logger.Debug("Received handshake initiation", slog.String("peer", peer.String()))
   415  			peer.rxBytes.Add(uint64(len(elem.packet)))
   416  
   417  			if err := peer.SendHandshakeResponse(); err != nil {
   418  				logger.Error("Failed to send handshake response", slog.Any("error", err))
   419  				goto skip
   420  			}
   421  
   422  		case MessageResponseType:
   423  
   424  			// unmarshal
   425  
   426  			var msg MessageResponse
   427  			reader := bytes.NewReader(elem.packet)
   428  			err := binary.Read(reader, binary.LittleEndian, &msg)
   429  			if err != nil {
   430  				logger.Warn("Failed to decode response message", slog.Any("error", err))
   431  				goto skip
   432  			}
   433  
   434  			// consume response
   435  
   436  			peer := transport.ConsumeMessageResponse(&msg)
   437  			if peer == nil {
   438  				logger.Warn("Received invalid response message")
   439  				goto skip
   440  			}
   441  
   442  			logger := logger.With(slog.String("peer", peer.String()))
   443  
   444  			// update endpoint
   445  			peer.SetEndpoint(elem.endpoint)
   446  
   447  			logger.Debug("Received handshake response")
   448  			peer.rxBytes.Add(uint64(len(elem.packet)))
   449  
   450  			// update timers
   451  
   452  			peer.timersAnyAuthenticatedPacketTraversal()
   453  			peer.timersAnyAuthenticatedPacketReceived()
   454  
   455  			// derive keypair
   456  			if err := peer.BeginSymmetricSession(); err != nil {
   457  				logger.Error("Failed to derive keypair", slog.Any("error", err))
   458  				goto skip
   459  			}
   460  
   461  			peer.timersSessionDerived()
   462  			peer.timersHandshakeComplete()
   463  			if err := peer.SendKeepalive(); err != nil {
   464  				logger.Error("Failed to send keepalive", slog.Any("error", err))
   465  				goto skip
   466  			}
   467  		}
   468  	skip:
   469  		transport.PutMessageBuffer(elem.buffer)
   470  	}
   471  }
   472  
   473  func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
   474  	t := peer.transport
   475  
   476  	logger := t.logger.With(slog.String("peer", peer.String()))
   477  
   478  	defer func() {
   479  		logger.Debug("Routine: sequential receiver - stopped")
   480  		peer.stopping.Done()
   481  	}()
   482  	logger.Debug("Routine: sequential receiver - started")
   483  
   484  	bufs := make([][]byte, 0, maxBatchSize)
   485  
   486  	peers := make([]types.NoisePublicKey, 0, maxBatchSize)
   487  	for i := 0; i < maxBatchSize; i++ {
   488  		peers = append(peers, peer.pk)
   489  	}
   490  
   491  	for elemsContainer := range peer.queue.inbound.c {
   492  		if elemsContainer == nil {
   493  			return
   494  		}
   495  		elemsContainer.Lock()
   496  		validTailPacket := -1
   497  		dataPacketReceived := false
   498  		rxBytesLen := uint64(0)
   499  		for i, elem := range elemsContainer.elems {
   500  			if elem.packet == nil {
   501  				// decryption failed
   502  				continue
   503  			}
   504  
   505  			if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
   506  				continue
   507  			}
   508  
   509  			validTailPacket = i
   510  			if peer.ReceivedWithKeypair(elem.keypair) {
   511  				peer.SetEndpoint(elem.endpoint)
   512  				peer.timersHandshakeComplete()
   513  				if err := peer.SendStagedPackets(); err != nil {
   514  					logger.Warn("Failed to send staged packets", slog.Any("error", err))
   515  					continue
   516  				}
   517  			}
   518  			rxBytesLen += uint64(len(elem.packet) + MinMessageSize)
   519  
   520  			if len(elem.packet) == 0 {
   521  				logger.Debug("Receiving keepalive packet")
   522  				continue
   523  			}
   524  			dataPacketReceived = true
   525  
   526  			bufs = append(bufs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)])
   527  		}
   528  
   529  		peer.rxBytes.Add(rxBytesLen)
   530  		if validTailPacket >= 0 {
   531  			peer.SetEndpoint(elemsContainer.elems[validTailPacket].endpoint)
   532  			if err := peer.keepKeyFreshReceiving(); err != nil {
   533  				logger.Warn("Failed to keep key fresh", slog.Any("error", err))
   534  				continue
   535  			}
   536  			peer.timersAnyAuthenticatedPacketTraversal()
   537  			peer.timersAnyAuthenticatedPacketReceived()
   538  		}
   539  		if dataPacketReceived {
   540  			peer.timersDataReceived()
   541  		}
   542  		if len(bufs) > 0 {
   543  			_, err := t.sourceSink.Write(bufs, peers, MessageTransportOffsetContent)
   544  			if err != nil && !t.isClosed() {
   545  				logger.Error("Failed to write packets to source sink", slog.Any("error", err))
   546  			}
   547  		}
   548  		for _, elem := range elemsContainer.elems {
   549  			t.PutMessageBuffer(elem.buffer)
   550  			t.PutInboundElement(elem)
   551  		}
   552  		bufs = bufs[:0]
   553  		t.PutInboundElementsContainer(elemsContainer)
   554  	}
   555  }