github.com/amnezia-vpn/amnezia-wg@v0.1.8/device/send.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"bytes"
    10  	"encoding/binary"
    11  	"errors"
    12  	"math/rand"
    13  	"net"
    14  	"os"
    15  	"sync"
    16  	"time"
    17  
    18  	"github.com/amnezia-vpn/amnezia-wg/conn"
    19  	"github.com/amnezia-vpn/amnezia-wg/tun"
    20  	"golang.org/x/crypto/chacha20poly1305"
    21  	"golang.org/x/net/ipv4"
    22  	"golang.org/x/net/ipv6"
    23  )
    24  
    25  /* Outbound flow
    26   *
    27   * 1. TUN queue
    28   * 2. Routing (sequential)
    29   * 3. Nonce assignment (sequential)
    30   * 4. Encryption (parallel)
    31   * 5. Transmission (sequential)
    32   *
    33   * The functions in this file occur (roughly) in the order in
    34   * which the packets are processed.
    35   *
    36   * Locking, Producers and Consumers
    37   *
    38   * The order of packets (per peer) must be maintained,
 * but encryption of packets happens out-of-order:
    40   *
    41   * The sequential consumers will attempt to take the lock,
    42   * workers release lock when they have completed work (encryption) on the packet.
    43   *
    44   * If the element is inserted into the "encryption queue",
    45   * the content is preceded by enough "junk" to contain the transport header
    46   * (to allow the construction of transport messages in-place)
    47   */
    48  
// QueueOutboundElement is a single outbound packet together with the state
// that is attached to it on its way through the outbound pipeline
// (nonce, keypair and peer are assigned in SendStagedPackets).
type QueueOutboundElement struct {
	buffer  *[MaxMessageSize]byte // backing array holding the packet data
	packet  []byte                // slice of "buffer" (always!)
	nonce   uint64                // nonce for encryption
	keypair *Keypair              // keypair for encryption
	peer    *Peer                 // related peer
}
    56  
// QueueOutboundElementsContainer is a batch of outbound elements that travels
// through the queues as one unit.
//
// Within this file the embedded mutex is used as a completion signal:
// SendStagedPackets locks it before queueing the batch, RoutineEncryption
// unlocks it once every element has been encrypted, and
// RoutineSequentialSender locks it again before touching the elements, so the
// sender blocks until encryption of the batch has finished.
type QueueOutboundElementsContainer struct {
	sync.Mutex
	elems []*QueueOutboundElement
}
    61  
    62  func (device *Device) NewOutboundElement() *QueueOutboundElement {
    63  	elem := device.GetOutboundElement()
    64  	elem.buffer = device.GetMessageBuffer()
    65  	elem.nonce = 0
    66  	// keypair and peer were cleared (if necessary) by clearPointers.
    67  	return elem
    68  }
    69  
    70  // clearPointers clears elem fields that contain pointers.
    71  // This makes the garbage collector's life easier and
    72  // avoids accidentally keeping other objects around unnecessarily.
    73  // It also reduces the possible collateral damage from use-after-free bugs.
    74  func (elem *QueueOutboundElement) clearPointers() {
    75  	elem.buffer = nil
    76  	elem.packet = nil
    77  	elem.keypair = nil
    78  	elem.peer = nil
    79  }
    80  
    81  /* Queues a keepalive if no packets are queued for peer
    82   */
    83  func (peer *Peer) SendKeepalive() {
    84  	if len(peer.queue.staged) == 0 && peer.isRunning.Load() {
    85  		elem := peer.device.NewOutboundElement()
    86  		elemsContainer := peer.device.GetOutboundElementsContainer()
    87  		elemsContainer.elems = append(elemsContainer.elems, elem)
    88  		select {
    89  		case peer.queue.staged <- elemsContainer:
    90  			peer.device.log.Verbosef("%v - Sending keepalive packet", peer)
    91  		default:
    92  			peer.device.PutMessageBuffer(elem.buffer)
    93  			peer.device.PutOutboundElement(elem)
    94  			peer.device.PutOutboundElementsContainer(elemsContainer)
    95  		}
    96  	}
    97  	peer.SendStagedPackets()
    98  }
    99  
   100  func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
   101  	if !isRetry {
   102  		peer.timers.handshakeAttempts.Store(0)
   103  	}
   104  
   105  	peer.handshake.mutex.RLock()
   106  	if time.Since(peer.handshake.lastSentHandshake) < RekeyTimeout {
   107  		peer.handshake.mutex.RUnlock()
   108  		return nil
   109  	}
   110  	peer.handshake.mutex.RUnlock()
   111  
   112  	peer.handshake.mutex.Lock()
   113  	if time.Since(peer.handshake.lastSentHandshake) < RekeyTimeout {
   114  		peer.handshake.mutex.Unlock()
   115  		return nil
   116  	}
   117  	peer.handshake.lastSentHandshake = time.Now()
   118  	peer.handshake.mutex.Unlock()
   119  
   120  	peer.device.log.Verbosef("%v - Sending handshake initiation", peer)
   121  
   122  	msg, err := peer.device.CreateMessageInitiation(peer)
   123  	if err != nil {
   124  		peer.device.log.Errorf("%v - Failed to create initiation message: %v", peer, err)
   125  		return err
   126  	}
   127  	var sendBuffer [][]byte
   128  	// so only packet processed for cookie generation
   129  	var junkedHeader []byte
   130  	if peer.device.isAdvancedSecurityOn() {
   131  		peer.device.aSecMux.RLock()
   132  		junks, err := peer.createJunkPackets()
   133  		peer.device.aSecMux.RUnlock()
   134  
   135  		if err != nil {
   136  			peer.device.log.Errorf("%v - %v", peer, err)
   137  			return err
   138  		}
   139  
   140  		if len(junks) > 0 {
   141  			err = peer.SendBuffers(junks)
   142  
   143  			if err != nil {
   144  				peer.device.log.Errorf("%v - Failed to send junk packets: %v", peer, err)
   145  				return err
   146  			}
   147  		}
   148  
   149  		peer.device.aSecMux.RLock()
   150  		if peer.device.aSecCfg.initPacketJunkSize != 0 {
   151  			buf := make([]byte, 0, peer.device.aSecCfg.initPacketJunkSize)
   152  			writer := bytes.NewBuffer(buf[:0])
   153  			err = appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize)
   154  			if err != nil {
   155  				peer.device.log.Errorf("%v - %v", peer, err)
   156  				peer.device.aSecMux.RUnlock()
   157  				return err
   158  			}
   159  			junkedHeader = writer.Bytes()
   160  		}
   161  		peer.device.aSecMux.RUnlock()
   162  	}
   163  
   164  	var buf [MessageInitiationSize]byte
   165  	writer := bytes.NewBuffer(buf[:0])
   166  	binary.Write(writer, binary.LittleEndian, msg)
   167  	packet := writer.Bytes()
   168  	peer.cookieGenerator.AddMacs(packet)
   169  	junkedHeader = append(junkedHeader, packet...)
   170  
   171  	peer.timersAnyAuthenticatedPacketTraversal()
   172  	peer.timersAnyAuthenticatedPacketSent()
   173  
   174  	sendBuffer = append(sendBuffer, junkedHeader)
   175  
   176  	err = peer.SendBuffers(sendBuffer)
   177  	if err != nil {
   178  		peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
   179  	}
   180  	peer.timersHandshakeInitiated()
   181  
   182  	return err
   183  }
   184  
   185  func (peer *Peer) SendHandshakeResponse() error {
   186  	peer.handshake.mutex.Lock()
   187  	peer.handshake.lastSentHandshake = time.Now()
   188  	peer.handshake.mutex.Unlock()
   189  
   190  	peer.device.log.Verbosef("%v - Sending handshake response", peer)
   191  
   192  	response, err := peer.device.CreateMessageResponse(peer)
   193  	if err != nil {
   194  		peer.device.log.Errorf("%v - Failed to create response message: %v", peer, err)
   195  		return err
   196  	}
   197  	var junkedHeader []byte
   198  	if peer.device.isAdvancedSecurityOn() {
   199  		peer.device.aSecMux.RLock()
   200  		if peer.device.aSecCfg.responsePacketJunkSize != 0 {
   201  			buf := make([]byte, 0, peer.device.aSecCfg.responsePacketJunkSize)
   202  			writer := bytes.NewBuffer(buf[:0])
   203  			err = appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize)
   204  			if err != nil {
   205  				peer.device.aSecMux.RUnlock()
   206  				peer.device.log.Errorf("%v - %v", peer, err)
   207  				return err
   208  			}
   209  			junkedHeader = writer.Bytes()
   210  		}
   211  		peer.device.aSecMux.RUnlock()
   212  	}
   213  	var buf [MessageResponseSize]byte
   214  	writer := bytes.NewBuffer(buf[:0])
   215  
   216  	binary.Write(writer, binary.LittleEndian, response)
   217  	packet := writer.Bytes()
   218  	peer.cookieGenerator.AddMacs(packet)
   219  	junkedHeader = append(junkedHeader, packet...)
   220  
   221  	err = peer.BeginSymmetricSession()
   222  	if err != nil {
   223  		peer.device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
   224  		return err
   225  	}
   226  
   227  	peer.timersSessionDerived()
   228  	peer.timersAnyAuthenticatedPacketTraversal()
   229  	peer.timersAnyAuthenticatedPacketSent()
   230  
   231  	// TODO: allocation could be avoided
   232  	err = peer.SendBuffers([][]byte{junkedHeader})
   233  	if err != nil {
   234  		peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
   235  	}
   236  	return err
   237  }
   238  
   239  func (device *Device) SendHandshakeCookie(
   240  	initiatingElem *QueueHandshakeElement,
   241  ) error {
   242  	device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString())
   243  
   244  	sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8])
   245  	reply, err := device.cookieChecker.CreateReply(
   246  		initiatingElem.packet,
   247  		sender,
   248  		initiatingElem.endpoint.DstToBytes(),
   249  	)
   250  	if err != nil {
   251  		device.log.Errorf("Failed to create cookie reply: %v", err)
   252  		return err
   253  	}
   254  
   255  	var buf [MessageCookieReplySize]byte
   256  	writer := bytes.NewBuffer(buf[:0])
   257  	binary.Write(writer, binary.LittleEndian, reply)
   258  	// TODO: allocation could be avoided
   259  	device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint)
   260  	return nil
   261  }
   262  
   263  func (peer *Peer) keepKeyFreshSending() {
   264  	keypair := peer.keypairs.Current()
   265  	if keypair == nil {
   266  		return
   267  	}
   268  	nonce := keypair.sendNonce.Load()
   269  	if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) {
   270  		peer.SendHandshakeInitiation(false)
   271  	}
   272  }
   273  
// RoutineReadFromTUN is the TUN reader loop. It reads batches of packets from
// the TUN device, looks up the destination peer of each packet in the
// allowed-IPs table, groups the packets per peer and stages them for sending.
//
// The loop ends on the first read error other than tun.ErrTooManySegments;
// if the device is not already closed at that point, Close is started in a
// separate goroutine. On exit the routine signals device.state.stopping and
// device.queue.encryption.wg.
func (device *Device) RoutineReadFromTUN() {
	defer func() {
		device.log.Verbosef("Routine: TUN reader - stopped")
		device.state.stopping.Done()
		device.queue.encryption.wg.Done()
	}()

	device.log.Verbosef("Routine: TUN reader - started")

	// elems[i] owns the buffer that bufs[i] points into; sizes[i] receives the
	// length of the packet read into bufs[i]. elemsByPeer collects the
	// elements of the current batch per destination peer.
	var (
		batchSize   = device.BatchSize()
		readErr     error
		elems       = make([]*QueueOutboundElement, batchSize)
		bufs        = make([][]byte, batchSize)
		elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize)
		count       = 0
		sizes       = make([]int, batchSize)
		offset      = MessageTransportHeaderSize
	)

	// Pre-populate every slot with an element and expose its whole buffer
	// to the TUN read.
	for i := range elems {
		elems[i] = device.NewOutboundElement()
		bufs[i] = elems[i].buffer[:]
	}

	// On exit, give the elements still held in the slots back to the pools.
	defer func() {
		for _, elem := range elems {
			if elem != nil {
				device.PutMessageBuffer(elem.buffer)
				device.PutOutboundElement(elem)
			}
		}
	}()

	for {
		// read packets
		// Packet data is placed at "offset" inside each buffer, leaving room
		// in front of it for the transport header. Packets returned together
		// with an error are still processed; the error is examined afterwards.
		count, readErr = device.tun.device.Read(bufs, sizes, offset)
		for i := 0; i < count; i++ {
			if sizes[i] < 1 {
				continue
			}

			elem := elems[i]
			elem.packet = bufs[i][offset : offset+sizes[i]]

			// lookup peer
			// The IP version is the high nibble of the first byte. Packets
			// shorter than the respective IP header are skipped.
			var peer *Peer
			switch elem.packet[0] >> 4 {
			case 4:
				if len(elem.packet) < ipv4.HeaderLen {
					continue
				}
				dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
				peer = device.allowedips.Lookup(dst)

			case 6:
				if len(elem.packet) < ipv6.HeaderLen {
					continue
				}
				dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
				peer = device.allowedips.Lookup(dst)

			default:
				device.log.Verbosef("Received packet with unknown IP version")
			}

			// No peer for this packet: skip it. The element stays in its slot
			// and is reused by the next read.
			if peer == nil {
				continue
			}
			elemsForPeer, ok := elemsByPeer[peer]
			if !ok {
				elemsForPeer = device.GetOutboundElementsContainer()
				elemsByPeer[peer] = elemsForPeer
			}
			elemsForPeer.elems = append(elemsForPeer.elems, elem)
			// The element now belongs to the peer's batch; refill the slot
			// with a fresh element for the next read.
			elems[i] = device.NewOutboundElement()
			bufs[i] = elems[i].buffer[:]
		}

		// Hand each per-peer batch over to its peer, or recycle it when the
		// peer is not running. The map is emptied for the next iteration.
		for peer, elemsForPeer := range elemsByPeer {
			if peer.isRunning.Load() {
				peer.StagePackets(elemsForPeer)
				peer.SendStagedPackets()
			} else {
				for _, elem := range elemsForPeer.elems {
					device.PutMessageBuffer(elem.buffer)
					device.PutOutboundElement(elem)
				}
				device.PutOutboundElementsContainer(elemsForPeer)
			}
			delete(elemsByPeer, peer)
		}

		if readErr != nil {
			if errors.Is(readErr, tun.ErrTooManySegments) {
				// TODO: record stat for this
				// This will happen if MSS is surprisingly small (< 576)
				// coincident with reasonably high throughput.
				device.log.Verbosef("Dropped some packets from multi-segment read: %v", readErr)
				continue
			}
			// Any other error ends the routine. os.ErrClosed is not logged.
			if !device.isClosed() {
				if !errors.Is(readErr, os.ErrClosed) {
					device.log.Errorf("Failed to read packet from TUN device: %v", readErr)
				}
				go device.Close()
			}
			return
		}
	}
}
   385  
   386  func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) {
   387  	for {
   388  		select {
   389  		case peer.queue.staged <- elems:
   390  			return
   391  		default:
   392  		}
   393  		select {
   394  		case tooOld := <-peer.queue.staged:
   395  			for _, elem := range tooOld.elems {
   396  				peer.device.PutMessageBuffer(elem.buffer)
   397  				peer.device.PutOutboundElement(elem)
   398  			}
   399  			peer.device.PutOutboundElementsContainer(tooOld)
   400  		default:
   401  		}
   402  	}
   403  }
   404  
// SendStagedPackets drains the peer's staging queue: every staged element
// gets the peer, a nonce and the current keypair assigned, and each batch is
// then queued for encryption and for sequential transmission.
//
// It returns immediately when nothing is staged or the device is not up.
// When there is no usable keypair (none, nonce limit reached, or too old),
// a handshake initiation is requested instead and the packets stay staged.
func (peer *Peer) SendStagedPackets() {
top:
	if len(peer.queue.staged) == 0 || !peer.device.isUp() {
		return
	}

	keypair := peer.keypairs.Current()
	if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime {
		peer.SendHandshakeInitiation(false)
		return
	}

	for {
		// Elements of the current batch that could not get a valid nonce.
		var elemsContainerOOO *QueueOutboundElementsContainer
		select {
		case elemsContainer := <-peer.queue.staged:
			// i counts the elements kept in elemsContainer; kept elements are
			// compacted to the front of the slice.
			i := 0
			for _, elem := range elemsContainer.elems {
				elem.peer = peer
				// Add returns the new value, so subtracting one yields the
				// nonce reserved for this element.
				elem.nonce = keypair.sendNonce.Add(1) - 1
				if elem.nonce >= RejectAfterMessages {
					// Nonce space exhausted: pin the counter at the limit
					// and set the element aside for re-staging.
					keypair.sendNonce.Store(RejectAfterMessages)
					if elemsContainerOOO == nil {
						elemsContainerOOO = peer.device.GetOutboundElementsContainer()
					}
					elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem)
					continue
				} else {
					elemsContainer.elems[i] = elem
					i++
				}

				// Only reached for kept elements (the branch above continues).
				elem.keypair = keypair
			}
			// Locked here; RoutineEncryption unlocks the container once the
			// batch is encrypted, which is what RoutineSequentialSender
			// waits for.
			elemsContainer.Lock()
			elemsContainer.elems = elemsContainer.elems[:i]

			if elemsContainerOOO != nil {
				peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans
			}

			// Nothing left to send in this batch: recycle the container and
			// start over from the keypair check.
			if len(elemsContainer.elems) == 0 {
				peer.device.PutOutboundElementsContainer(elemsContainer)
				goto top
			}

			// add to parallel and sequential queue
			// The same container goes to both queues: the per-peer outbound
			// queue fixes the transmission order, the device-wide encryption
			// queue feeds the encryption workers.
			if peer.isRunning.Load() {
				peer.queue.outbound.c <- elemsContainer
				peer.device.queue.encryption.c <- elemsContainer
			} else {
				for _, elem := range elemsContainer.elems {
					peer.device.PutMessageBuffer(elem.buffer)
					peer.device.PutOutboundElement(elem)
				}
				peer.device.PutOutboundElementsContainer(elemsContainer)
			}

			// Some elements were re-staged because the nonce limit was hit:
			// go back to the keypair check at the top.
			if elemsContainerOOO != nil {
				goto top
			}
		default:
			return
		}
	}
}
   471  
   472  func (peer *Peer) createJunkPackets() ([][]byte, error) {
   473  	if peer.device.aSecCfg.junkPacketCount == 0 {
   474  		return nil, nil
   475  	}
   476  
   477  	junks := make([][]byte, 0, peer.device.aSecCfg.junkPacketCount)
   478  	for i := 0; i < peer.device.aSecCfg.junkPacketCount; i++ {
   479  		packetSize := rand.Intn(
   480  			peer.device.aSecCfg.junkPacketMaxSize-peer.device.aSecCfg.junkPacketMinSize,
   481  		) + peer.device.aSecCfg.junkPacketMinSize
   482  
   483  		junk, err := randomJunkWithSize(packetSize)
   484  		if err != nil {
   485  			peer.device.log.Errorf(
   486  				"%v - Failed to create junk packet: %v",
   487  				peer,
   488  				err,
   489  			)
   490  			return nil, err
   491  		}
   492  		junks = append(junks, junk)
   493  	}
   494  	return junks, nil
   495  }
   496  
   497  func (peer *Peer) FlushStagedPackets() {
   498  	for {
   499  		select {
   500  		case elemsContainer := <-peer.queue.staged:
   501  			for _, elem := range elemsContainer.elems {
   502  				peer.device.PutMessageBuffer(elem.buffer)
   503  				peer.device.PutOutboundElement(elem)
   504  			}
   505  			peer.device.PutOutboundElementsContainer(elemsContainer)
   506  		default:
   507  			return
   508  		}
   509  	}
   510  }
   511  
   512  func calculatePaddingSize(packetSize, mtu int) int {
   513  	lastUnit := packetSize
   514  	if mtu == 0 {
   515  		return ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)) - lastUnit
   516  	}
   517  	if lastUnit > mtu {
   518  		lastUnit %= mtu
   519  	}
   520  	paddedSize := ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1))
   521  	if paddedSize > mtu {
   522  		paddedSize = mtu
   523  	}
   524  	return paddedSize - lastUnit
   525  }
   526  
   527  /* Encrypts the elements in the queue
   528   * and marks them for sequential consumption (by releasing the mutex)
   529   *
   530   * Obs. One instance per core
   531   */
   532  func (device *Device) RoutineEncryption(id int) {
   533  	var paddingZeros [PaddingMultiple]byte
   534  	var nonce [chacha20poly1305.NonceSize]byte
   535  
   536  	defer device.log.Verbosef("Routine: encryption worker %d - stopped", id)
   537  	device.log.Verbosef("Routine: encryption worker %d - started", id)
   538  
   539  	for elemsContainer := range device.queue.encryption.c {
   540  		for _, elem := range elemsContainer.elems {
   541  			// populate header fields
   542  			header := elem.buffer[:MessageTransportHeaderSize]
   543  
   544  			fieldType := header[0:4]
   545  			fieldReceiver := header[4:8]
   546  			fieldNonce := header[8:16]
   547  
   548  			binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
   549  			binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
   550  			binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
   551  
   552  			// pad content to multiple of 16
   553  			paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load()))
   554  			elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)
   555  
   556  			// encrypt content and release to consumer
   557  
   558  			binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
   559  			elem.packet = elem.keypair.send.Seal(
   560  				header,
   561  				nonce[:],
   562  				elem.packet,
   563  				nil,
   564  			)
   565  		}
   566  		elemsContainer.Unlock()
   567  	}
   568  }
   569  
   570  func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
   571  	device := peer.device
   572  	defer func() {
   573  		defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer)
   574  		peer.stopping.Done()
   575  	}()
   576  	device.log.Verbosef("%v - Routine: sequential sender - started", peer)
   577  
   578  	bufs := make([][]byte, 0, maxBatchSize)
   579  
   580  	for elemsContainer := range peer.queue.outbound.c {
   581  		bufs = bufs[:0]
   582  		if elemsContainer == nil {
   583  			return
   584  		}
   585  		if !peer.isRunning.Load() {
   586  			// peer has been stopped; return re-usable elems to the shared pool.
   587  			// This is an optimization only. It is possible for the peer to be stopped
   588  			// immediately after this check, in which case, elem will get processed.
   589  			// The timers and SendBuffers code are resilient to a few stragglers.
   590  			// TODO: rework peer shutdown order to ensure
   591  			// that we never accidentally keep timers alive longer than necessary.
   592  			elemsContainer.Lock()
   593  			for _, elem := range elemsContainer.elems {
   594  				device.PutMessageBuffer(elem.buffer)
   595  				device.PutOutboundElement(elem)
   596  			}
   597  			continue
   598  		}
   599  		dataSent := false
   600  		elemsContainer.Lock()
   601  		for _, elem := range elemsContainer.elems {
   602  			if len(elem.packet) != MessageKeepaliveSize {
   603  				dataSent = true
   604  			}
   605  			bufs = append(bufs, elem.packet)
   606  		}
   607  
   608  		peer.timersAnyAuthenticatedPacketTraversal()
   609  		peer.timersAnyAuthenticatedPacketSent()
   610  
   611  		err := peer.SendBuffers(bufs)
   612  		if dataSent {
   613  			peer.timersDataSent()
   614  		}
   615  		for _, elem := range elemsContainer.elems {
   616  			device.PutMessageBuffer(elem.buffer)
   617  			device.PutOutboundElement(elem)
   618  		}
   619  		device.PutOutboundElementsContainer(elemsContainer)
   620  		if err != nil {
   621  			var errGSO conn.ErrUDPGSODisabled
   622  			if errors.As(err, &errGSO) {
   623  				device.log.Verbosef(err.Error())
   624  				err = errGSO.RetryErr
   625  			}
   626  		}
   627  		if err != nil {
   628  			device.log.Errorf("%v - Failed to send data packets: %v", peer, err)
   629  			continue
   630  		}
   631  
   632  		peer.keepKeyFreshSending()
   633  	}
   634  }