github.com/koomox/wireguard-go@v0.0.0-20230722134753-17a50b2f22a3/device/send.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"bytes"
    10  	"encoding/binary"
    11  	"errors"
    12  	"net"
    13  	"os"
    14  	"sync"
    15  	"time"
    16  
    17  	"golang.org/x/crypto/chacha20poly1305"
    18  	"golang.org/x/net/ipv4"
    19  	"golang.org/x/net/ipv6"
    20  	"github.com/koomox/wireguard-go/tun"
    21  )
    22  
    23  /* Outbound flow
    24   *
    25   * 1. TUN queue
    26   * 2. Routing (sequential)
    27   * 3. Nonce assignment (sequential)
    28   * 4. Encryption (parallel)
    29   * 5. Transmission (sequential)
    30   *
    31   * The functions in this file occur (roughly) in the order in
    32   * which the packets are processed.
    33   *
    34   * Locking, Producers and Consumers
    35   *
    36   * The order of packets (per peer) must be maintained,
    37   * but encryption of packets happens out-of-order:
    38   *
    39   * The sequential consumers will attempt to take the lock,
    40   * workers release lock when they have completed work (encryption) on the packet.
    41   *
    42   * If the element is inserted into the "encryption queue",
    43   * the content is preceded by enough "junk" to contain the transport header
    44   * (to allow the construction of transport messages in-place)
    45   */
    46  
// QueueOutboundElement is one outbound packet together with the state it
// carries through the staging, encryption and transmission queues.
//
// The embedded mutex signals encryption progress: SendStagedPackets locks the
// element before queueing it for encryption, RoutineEncryption unlocks it once
// the packet has been sealed, and RoutineSequentialSender locks it again to
// wait for that point before transmitting.
type QueueOutboundElement struct {
	sync.Mutex
	buffer  *[MaxMessageSize]byte // backing array holding the packet data
	packet  []byte                // slice of "buffer" (always!)
	nonce   uint64                // nonce for encryption
	keypair *Keypair              // keypair for encryption
	peer    *Peer                 // related peer
}
    55  
    56  func (device *Device) NewOutboundElement() *QueueOutboundElement {
    57  	elem := device.GetOutboundElement()
    58  	elem.buffer = device.GetMessageBuffer()
    59  	elem.Mutex = sync.Mutex{}
    60  	elem.nonce = 0
    61  	// keypair and peer were cleared (if necessary) by clearPointers.
    62  	return elem
    63  }
    64  
    65  // clearPointers clears elem fields that contain pointers.
    66  // This makes the garbage collector's life easier and
    67  // avoids accidentally keeping other objects around unnecessarily.
    68  // It also reduces the possible collateral damage from use-after-free bugs.
    69  func (elem *QueueOutboundElement) clearPointers() {
    70  	elem.buffer = nil
    71  	elem.packet = nil
    72  	elem.keypair = nil
    73  	elem.peer = nil
    74  }
    75  
    76  /* Queues a keepalive if no packets are queued for peer
    77   */
    78  func (peer *Peer) SendKeepalive() {
    79  	if len(peer.queue.staged) == 0 && peer.isRunning.Load() {
    80  		elem := peer.device.NewOutboundElement()
    81  		elems := peer.device.GetOutboundElementsSlice()
    82  		*elems = append(*elems, elem)
    83  		select {
    84  		case peer.queue.staged <- elems:
    85  			peer.device.log.Verbosef("%v - Sending keepalive packet", peer)
    86  		default:
    87  			peer.device.PutMessageBuffer(elem.buffer)
    88  			peer.device.PutOutboundElement(elem)
    89  			peer.device.PutOutboundElementsSlice(elems)
    90  		}
    91  	}
    92  	peer.SendStagedPackets()
    93  }
    94  
    95  func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
    96  	if !isRetry {
    97  		peer.timers.handshakeAttempts.Store(0)
    98  	}
    99  
   100  	peer.handshake.mutex.RLock()
   101  	if time.Since(peer.handshake.lastSentHandshake) < RekeyTimeout {
   102  		peer.handshake.mutex.RUnlock()
   103  		return nil
   104  	}
   105  	peer.handshake.mutex.RUnlock()
   106  
   107  	peer.handshake.mutex.Lock()
   108  	if time.Since(peer.handshake.lastSentHandshake) < RekeyTimeout {
   109  		peer.handshake.mutex.Unlock()
   110  		return nil
   111  	}
   112  	peer.handshake.lastSentHandshake = time.Now()
   113  	peer.handshake.mutex.Unlock()
   114  
   115  	peer.device.log.Verbosef("%v - Sending handshake initiation", peer)
   116  
   117  	msg, err := peer.device.CreateMessageInitiation(peer)
   118  	if err != nil {
   119  		peer.device.log.Errorf("%v - Failed to create initiation message: %v", peer, err)
   120  		return err
   121  	}
   122  
   123  	var buf [MessageInitiationSize]byte
   124  	writer := bytes.NewBuffer(buf[:0])
   125  	binary.Write(writer, binary.LittleEndian, msg)
   126  	packet := writer.Bytes()
   127  	peer.cookieGenerator.AddMacs(packet)
   128  
   129  	peer.timersAnyAuthenticatedPacketTraversal()
   130  	peer.timersAnyAuthenticatedPacketSent()
   131  
   132  	err = peer.SendBuffers([][]byte{packet})
   133  	if err != nil {
   134  		peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
   135  	}
   136  	peer.timersHandshakeInitiated()
   137  
   138  	return err
   139  }
   140  
   141  func (peer *Peer) SendHandshakeResponse() error {
   142  	peer.handshake.mutex.Lock()
   143  	peer.handshake.lastSentHandshake = time.Now()
   144  	peer.handshake.mutex.Unlock()
   145  
   146  	peer.device.log.Verbosef("%v - Sending handshake response", peer)
   147  
   148  	response, err := peer.device.CreateMessageResponse(peer)
   149  	if err != nil {
   150  		peer.device.log.Errorf("%v - Failed to create response message: %v", peer, err)
   151  		return err
   152  	}
   153  
   154  	var buf [MessageResponseSize]byte
   155  	writer := bytes.NewBuffer(buf[:0])
   156  	binary.Write(writer, binary.LittleEndian, response)
   157  	packet := writer.Bytes()
   158  	peer.cookieGenerator.AddMacs(packet)
   159  
   160  	err = peer.BeginSymmetricSession()
   161  	if err != nil {
   162  		peer.device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
   163  		return err
   164  	}
   165  
   166  	peer.timersSessionDerived()
   167  	peer.timersAnyAuthenticatedPacketTraversal()
   168  	peer.timersAnyAuthenticatedPacketSent()
   169  
   170  	// TODO: allocation could be avoided
   171  	err = peer.SendBuffers([][]byte{packet})
   172  	if err != nil {
   173  		peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
   174  	}
   175  	return err
   176  }
   177  
   178  func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) error {
   179  	device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString())
   180  
   181  	sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8])
   182  	reply, err := device.cookieChecker.CreateReply(initiatingElem.packet, sender, initiatingElem.endpoint.DstToBytes())
   183  	if err != nil {
   184  		device.log.Errorf("Failed to create cookie reply: %v", err)
   185  		return err
   186  	}
   187  
   188  	var buf [MessageCookieReplySize]byte
   189  	writer := bytes.NewBuffer(buf[:0])
   190  	binary.Write(writer, binary.LittleEndian, reply)
   191  	// TODO: allocation could be avoided
   192  	device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint)
   193  	return nil
   194  }
   195  
   196  func (peer *Peer) keepKeyFreshSending() {
   197  	keypair := peer.keypairs.Current()
   198  	if keypair == nil {
   199  		return
   200  	}
   201  	nonce := keypair.sendNonce.Load()
   202  	if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) {
   203  		peer.SendHandshakeInitiation(false)
   204  	}
   205  }
   206  
   207  func (device *Device) RoutineReadFromTUN() {
   208  	defer func() {
   209  		device.log.Verbosef("Routine: TUN reader - stopped")
   210  		device.state.stopping.Done()
   211  		device.queue.encryption.wg.Done()
   212  	}()
   213  
   214  	device.log.Verbosef("Routine: TUN reader - started")
   215  
   216  	var (
   217  		batchSize   = device.BatchSize()
   218  		readErr     error
   219  		elems       = make([]*QueueOutboundElement, batchSize)
   220  		bufs        = make([][]byte, batchSize)
   221  		elemsByPeer = make(map[*Peer]*[]*QueueOutboundElement, batchSize)
   222  		count       = 0
   223  		sizes       = make([]int, batchSize)
   224  		offset      = MessageTransportHeaderSize
   225  	)
   226  
   227  	for i := range elems {
   228  		elems[i] = device.NewOutboundElement()
   229  		bufs[i] = elems[i].buffer[:]
   230  	}
   231  
   232  	defer func() {
   233  		for _, elem := range elems {
   234  			if elem != nil {
   235  				device.PutMessageBuffer(elem.buffer)
   236  				device.PutOutboundElement(elem)
   237  			}
   238  		}
   239  	}()
   240  
   241  	for {
   242  		// read packets
   243  		count, readErr = device.tun.device.Read(bufs, sizes, offset)
   244  		for i := 0; i < count; i++ {
   245  			if sizes[i] < 1 {
   246  				continue
   247  			}
   248  
   249  			elem := elems[i]
   250  			elem.packet = bufs[i][offset : offset+sizes[i]]
   251  
   252  			// lookup peer
   253  			var peer *Peer
   254  			switch elem.packet[0] >> 4 {
   255  			case 4:
   256  				if len(elem.packet) < ipv4.HeaderLen {
   257  					continue
   258  				}
   259  				dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
   260  				peer = device.allowedips.Lookup(dst)
   261  
   262  			case 6:
   263  				if len(elem.packet) < ipv6.HeaderLen {
   264  					continue
   265  				}
   266  				dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
   267  				peer = device.allowedips.Lookup(dst)
   268  
   269  			default:
   270  				device.log.Verbosef("Received packet with unknown IP version")
   271  			}
   272  
   273  			if peer == nil {
   274  				continue
   275  			}
   276  			elemsForPeer, ok := elemsByPeer[peer]
   277  			if !ok {
   278  				elemsForPeer = device.GetOutboundElementsSlice()
   279  				elemsByPeer[peer] = elemsForPeer
   280  			}
   281  			*elemsForPeer = append(*elemsForPeer, elem)
   282  			elems[i] = device.NewOutboundElement()
   283  			bufs[i] = elems[i].buffer[:]
   284  		}
   285  
   286  		for peer, elemsForPeer := range elemsByPeer {
   287  			if peer.isRunning.Load() {
   288  				peer.StagePackets(elemsForPeer)
   289  				peer.SendStagedPackets()
   290  			} else {
   291  				for _, elem := range *elemsForPeer {
   292  					device.PutMessageBuffer(elem.buffer)
   293  					device.PutOutboundElement(elem)
   294  				}
   295  				device.PutOutboundElementsSlice(elemsForPeer)
   296  			}
   297  			delete(elemsByPeer, peer)
   298  		}
   299  
   300  		if readErr != nil {
   301  			if errors.Is(readErr, tun.ErrTooManySegments) {
   302  				// TODO: record stat for this
   303  				// This will happen if MSS is surprisingly small (< 576)
   304  				// coincident with reasonably high throughput.
   305  				device.log.Verbosef("Dropped some packets from multi-segment read: %v", readErr)
   306  				continue
   307  			}
   308  			if !device.isClosed() {
   309  				if !errors.Is(readErr, os.ErrClosed) {
   310  					device.log.Errorf("Failed to read packet from TUN device: %v", readErr)
   311  				}
   312  				go device.Close()
   313  			}
   314  			return
   315  		}
   316  	}
   317  }
   318  
   319  func (peer *Peer) StagePackets(elems *[]*QueueOutboundElement) {
   320  	for {
   321  		select {
   322  		case peer.queue.staged <- elems:
   323  			return
   324  		default:
   325  		}
   326  		select {
   327  		case tooOld := <-peer.queue.staged:
   328  			for _, elem := range *tooOld {
   329  				peer.device.PutMessageBuffer(elem.buffer)
   330  				peer.device.PutOutboundElement(elem)
   331  			}
   332  			peer.device.PutOutboundElementsSlice(tooOld)
   333  		default:
   334  		}
   335  	}
   336  }
   337  
   338  func (peer *Peer) SendStagedPackets() {
   339  top:
   340  	if len(peer.queue.staged) == 0 || !peer.device.isUp() {
   341  		return
   342  	}
   343  
   344  	keypair := peer.keypairs.Current()
   345  	if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime {
   346  		peer.SendHandshakeInitiation(false)
   347  		return
   348  	}
   349  
   350  	for {
   351  		var elemsOOO *[]*QueueOutboundElement
   352  		select {
   353  		case elems := <-peer.queue.staged:
   354  			i := 0
   355  			for _, elem := range *elems {
   356  				elem.peer = peer
   357  				elem.nonce = keypair.sendNonce.Add(1) - 1
   358  				if elem.nonce >= RejectAfterMessages {
   359  					keypair.sendNonce.Store(RejectAfterMessages)
   360  					if elemsOOO == nil {
   361  						elemsOOO = peer.device.GetOutboundElementsSlice()
   362  					}
   363  					*elemsOOO = append(*elemsOOO, elem)
   364  					continue
   365  				} else {
   366  					(*elems)[i] = elem
   367  					i++
   368  				}
   369  
   370  				elem.keypair = keypair
   371  				elem.Lock()
   372  			}
   373  			*elems = (*elems)[:i]
   374  
   375  			if elemsOOO != nil {
   376  				peer.StagePackets(elemsOOO) // XXX: Out of order, but we can't front-load go chans
   377  			}
   378  
   379  			if len(*elems) == 0 {
   380  				peer.device.PutOutboundElementsSlice(elems)
   381  				goto top
   382  			}
   383  
   384  			// add to parallel and sequential queue
   385  			if peer.isRunning.Load() {
   386  				peer.queue.outbound.c <- elems
   387  				for _, elem := range *elems {
   388  					peer.device.queue.encryption.c <- elem
   389  				}
   390  			} else {
   391  				for _, elem := range *elems {
   392  					peer.device.PutMessageBuffer(elem.buffer)
   393  					peer.device.PutOutboundElement(elem)
   394  				}
   395  				peer.device.PutOutboundElementsSlice(elems)
   396  			}
   397  
   398  			if elemsOOO != nil {
   399  				goto top
   400  			}
   401  		default:
   402  			return
   403  		}
   404  	}
   405  }
   406  
   407  func (peer *Peer) FlushStagedPackets() {
   408  	for {
   409  		select {
   410  		case elems := <-peer.queue.staged:
   411  			for _, elem := range *elems {
   412  				peer.device.PutMessageBuffer(elem.buffer)
   413  				peer.device.PutOutboundElement(elem)
   414  			}
   415  			peer.device.PutOutboundElementsSlice(elems)
   416  		default:
   417  			return
   418  		}
   419  	}
   420  }
   421  
   422  func calculatePaddingSize(packetSize, mtu int) int {
   423  	lastUnit := packetSize
   424  	if mtu == 0 {
   425  		return ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)) - lastUnit
   426  	}
   427  	if lastUnit > mtu {
   428  		lastUnit %= mtu
   429  	}
   430  	paddedSize := ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1))
   431  	if paddedSize > mtu {
   432  		paddedSize = mtu
   433  	}
   434  	return paddedSize - lastUnit
   435  }
   436  
   437  /* Encrypts the elements in the queue
   438   * and marks them for sequential consumption (by releasing the mutex)
   439   *
   440   * Obs. One instance per core
   441   */
   442  func (device *Device) RoutineEncryption(id int) {
   443  	var paddingZeros [PaddingMultiple]byte
   444  	var nonce [chacha20poly1305.NonceSize]byte
   445  
   446  	defer device.log.Verbosef("Routine: encryption worker %d - stopped", id)
   447  	device.log.Verbosef("Routine: encryption worker %d - started", id)
   448  
   449  	for elem := range device.queue.encryption.c {
   450  		// populate header fields
   451  		header := elem.buffer[:MessageTransportHeaderSize]
   452  
   453  		fieldType := header[0:4]
   454  		fieldReceiver := header[4:8]
   455  		fieldNonce := header[8:16]
   456  
   457  		binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
   458  		binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
   459  		binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
   460  
   461  		// pad content to multiple of 16
   462  		paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load()))
   463  		elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)
   464  
   465  		// encrypt content and release to consumer
   466  
   467  		binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
   468  		elem.packet = elem.keypair.send.Seal(
   469  			header,
   470  			nonce[:],
   471  			elem.packet,
   472  			nil,
   473  		)
   474  		elem.Unlock()
   475  	}
   476  }
   477  
   478  func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
   479  	device := peer.device
   480  	defer func() {
   481  		defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer)
   482  		peer.stopping.Done()
   483  	}()
   484  	device.log.Verbosef("%v - Routine: sequential sender - started", peer)
   485  
   486  	bufs := make([][]byte, 0, maxBatchSize)
   487  
   488  	for elems := range peer.queue.outbound.c {
   489  		bufs = bufs[:0]
   490  		if elems == nil {
   491  			return
   492  		}
   493  		if !peer.isRunning.Load() {
   494  			// peer has been stopped; return re-usable elems to the shared pool.
   495  			// This is an optimization only. It is possible for the peer to be stopped
   496  			// immediately after this check, in which case, elem will get processed.
   497  			// The timers and SendBuffers code are resilient to a few stragglers.
   498  			// TODO: rework peer shutdown order to ensure
   499  			// that we never accidentally keep timers alive longer than necessary.
   500  			for _, elem := range *elems {
   501  				elem.Lock()
   502  				device.PutMessageBuffer(elem.buffer)
   503  				device.PutOutboundElement(elem)
   504  			}
   505  			continue
   506  		}
   507  		dataSent := false
   508  		for _, elem := range *elems {
   509  			elem.Lock()
   510  			if len(elem.packet) != MessageKeepaliveSize {
   511  				dataSent = true
   512  			}
   513  			bufs = append(bufs, elem.packet)
   514  		}
   515  
   516  		peer.timersAnyAuthenticatedPacketTraversal()
   517  		peer.timersAnyAuthenticatedPacketSent()
   518  
   519  		err := peer.SendBuffers(bufs)
   520  		if dataSent {
   521  			peer.timersDataSent()
   522  		}
   523  		for _, elem := range *elems {
   524  			device.PutMessageBuffer(elem.buffer)
   525  			device.PutOutboundElement(elem)
   526  		}
   527  		device.PutOutboundElementsSlice(elems)
   528  		if err != nil {
   529  			device.log.Errorf("%v - Failed to send data packets: %v", peer, err)
   530  			continue
   531  		}
   532  
   533  		peer.keepKeyFreshSending()
   534  	}
   535  }