github.com/sagernet/wireguard-go@v0.0.0-20231215174105-89dec3b2f3e8/device/send.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"bytes"
    10  	"encoding/binary"
    11  	"errors"
    12  	"net"
    13  	"os"
    14  	"sync"
    15  	"time"
    16  
    17  	"github.com/sagernet/wireguard-go/conn"
    18  	"github.com/sagernet/wireguard-go/tun"
    19  	"golang.org/x/crypto/chacha20poly1305"
    20  	"golang.org/x/net/ipv4"
    21  	"golang.org/x/net/ipv6"
    22  )
    23  
    24  /* Outbound flow
    25   *
    26   * 1. TUN queue
    27   * 2. Routing (sequential)
    28   * 3. Nonce assignment (sequential)
    29   * 4. Encryption (parallel)
    30   * 5. Transmission (sequential)
    31   *
    32   * The functions in this file occur (roughly) in the order in
    33   * which the packets are processed.
    34   *
    35   * Locking, Producers and Consumers
    36   *
    37   * The order of packets (per peer) must be maintained,
    38   * but encryption of packets happens out-of-order:
    39   *
    40   * The sequential consumers will attempt to take the lock;
    41   * workers release the lock when they have completed work (encryption) on the packet.
    42   *
    43   * If the element is inserted into the "encryption queue",
    44   * the content is preceded by enough "junk" to contain the transport header
    45   * (to allow the construction of transport messages in-place)
    46   */
    47  
    48  type QueueOutboundElement struct {
    49  	buffer  *[MaxMessageSize]byte // slice holding the packet data
    50  	packet  []byte                // slice of "buffer" (always!)
    51  	nonce   uint64                // nonce for encryption
    52  	keypair *Keypair              // keypair for encryption
    53  	peer    *Peer                 // related peer
    54  }
    55  
    56  type QueueOutboundElementsContainer struct {
    57  	sync.Mutex
    58  	elems []*QueueOutboundElement
    59  }
    60  
    61  func (device *Device) NewOutboundElement() *QueueOutboundElement {
    62  	elem := device.GetOutboundElement()
    63  	elem.buffer = device.GetMessageBuffer()
    64  	elem.nonce = 0
    65  	// keypair and peer were cleared (if necessary) by clearPointers.
    66  	return elem
    67  }
    68  
    69  // clearPointers clears elem fields that contain pointers.
    70  // This makes the garbage collector's life easier and
    71  // avoids accidentally keeping other objects around unnecessarily.
    72  // It also reduces the possible collateral damage from use-after-free bugs.
    73  func (elem *QueueOutboundElement) clearPointers() {
    74  	elem.buffer = nil
    75  	elem.packet = nil
    76  	elem.keypair = nil
    77  	elem.peer = nil
    78  }
    79  
    80  /* Queues a keepalive if no packets are queued for peer
    81   */
    82  func (peer *Peer) SendKeepalive() {
    83  	if len(peer.queue.staged) == 0 && peer.isRunning.Load() {
    84  		elem := peer.device.NewOutboundElement()
    85  		elemsContainer := peer.device.GetOutboundElementsContainer()
    86  		elemsContainer.elems = append(elemsContainer.elems, elem)
    87  		select {
    88  		case peer.queue.staged <- elemsContainer:
    89  			peer.device.log.Verbosef("%v - Sending keepalive packet", peer)
    90  		default:
    91  			peer.device.PutMessageBuffer(elem.buffer)
    92  			peer.device.PutOutboundElement(elem)
    93  			peer.device.PutOutboundElementsContainer(elemsContainer)
    94  		}
    95  	}
    96  	peer.SendStagedPackets()
    97  }
    98  
    99  func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
   100  	if !isRetry {
   101  		peer.timers.handshakeAttempts.Store(0)
   102  	}
   103  
   104  	peer.handshake.mutex.RLock()
   105  	if time.Since(peer.handshake.lastSentHandshake) < RekeyTimeout {
   106  		peer.handshake.mutex.RUnlock()
   107  		return nil
   108  	}
   109  	peer.handshake.mutex.RUnlock()
   110  
   111  	peer.handshake.mutex.Lock()
   112  	if time.Since(peer.handshake.lastSentHandshake) < RekeyTimeout {
   113  		peer.handshake.mutex.Unlock()
   114  		return nil
   115  	}
   116  	peer.handshake.lastSentHandshake = time.Now()
   117  	peer.handshake.mutex.Unlock()
   118  
   119  	peer.device.log.Verbosef("%v - Sending handshake initiation", peer)
   120  
   121  	msg, err := peer.device.CreateMessageInitiation(peer)
   122  	if err != nil {
   123  		peer.device.log.Errorf("%v - Failed to create initiation message: %v", peer, err)
   124  		return err
   125  	}
   126  
   127  	var buf [MessageInitiationSize]byte
   128  	writer := bytes.NewBuffer(buf[:0])
   129  	binary.Write(writer, binary.LittleEndian, msg)
   130  	packet := writer.Bytes()
   131  	peer.cookieGenerator.AddMacs(packet)
   132  
   133  	peer.timersAnyAuthenticatedPacketTraversal()
   134  	peer.timersAnyAuthenticatedPacketSent()
   135  
   136  	err = peer.SendBuffers([][]byte{packet})
   137  	if err != nil {
   138  		peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
   139  	}
   140  	peer.timersHandshakeInitiated()
   141  
   142  	return err
   143  }
   144  
   145  func (peer *Peer) SendHandshakeResponse() error {
   146  	peer.handshake.mutex.Lock()
   147  	peer.handshake.lastSentHandshake = time.Now()
   148  	peer.handshake.mutex.Unlock()
   149  
   150  	peer.device.log.Verbosef("%v - Sending handshake response", peer)
   151  
   152  	response, err := peer.device.CreateMessageResponse(peer)
   153  	if err != nil {
   154  		peer.device.log.Errorf("%v - Failed to create response message: %v", peer, err)
   155  		return err
   156  	}
   157  
   158  	var buf [MessageResponseSize]byte
   159  	writer := bytes.NewBuffer(buf[:0])
   160  	binary.Write(writer, binary.LittleEndian, response)
   161  	packet := writer.Bytes()
   162  	peer.cookieGenerator.AddMacs(packet)
   163  
   164  	err = peer.BeginSymmetricSession()
   165  	if err != nil {
   166  		peer.device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
   167  		return err
   168  	}
   169  
   170  	peer.timersSessionDerived()
   171  	peer.timersAnyAuthenticatedPacketTraversal()
   172  	peer.timersAnyAuthenticatedPacketSent()
   173  
   174  	// TODO: allocation could be avoided
   175  	err = peer.SendBuffers([][]byte{packet})
   176  	if err != nil {
   177  		peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
   178  	}
   179  	return err
   180  }
   181  
   182  func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) error {
   183  	device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString())
   184  
   185  	sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8])
   186  	reply, err := device.cookieChecker.CreateReply(initiatingElem.packet, sender, initiatingElem.endpoint.DstToBytes())
   187  	if err != nil {
   188  		device.log.Errorf("Failed to create cookie reply: %v", err)
   189  		return err
   190  	}
   191  
   192  	var buf [MessageCookieReplySize]byte
   193  	writer := bytes.NewBuffer(buf[:0])
   194  	binary.Write(writer, binary.LittleEndian, reply)
   195  	// TODO: allocation could be avoided
   196  	device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint)
   197  	return nil
   198  }
   199  
   200  func (peer *Peer) keepKeyFreshSending() {
   201  	keypair := peer.keypairs.Current()
   202  	if keypair == nil {
   203  		return
   204  	}
   205  	nonce := keypair.sendNonce.Load()
   206  	if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) {
   207  		peer.SendHandshakeInitiation(false)
   208  	}
   209  }
   210  
   211  func (device *Device) RoutineReadFromTUN() {
   212  	defer func() {
   213  		device.log.Verbosef("Routine: TUN reader - stopped")
   214  		device.state.stopping.Done()
   215  		device.queue.encryption.wg.Done()
   216  	}()
   217  
   218  	device.log.Verbosef("Routine: TUN reader - started")
   219  
   220  	var (
   221  		batchSize   = device.BatchSize()
   222  		readErr     error
   223  		elems       = make([]*QueueOutboundElement, batchSize)
   224  		bufs        = make([][]byte, batchSize)
   225  		elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize)
   226  		count       = 0
   227  		sizes       = make([]int, batchSize)
   228  		offset      = MessageTransportHeaderSize
   229  	)
   230  
   231  	for i := range elems {
   232  		elems[i] = device.NewOutboundElement()
   233  		bufs[i] = elems[i].buffer[:]
   234  	}
   235  
   236  	defer func() {
   237  		for _, elem := range elems {
   238  			if elem != nil {
   239  				device.PutMessageBuffer(elem.buffer)
   240  				device.PutOutboundElement(elem)
   241  			}
   242  		}
   243  	}()
   244  
   245  	for {
   246  		// read packets
   247  		count, readErr = device.tun.device.Read(bufs, sizes, offset)
   248  		for i := 0; i < count; i++ {
   249  			if sizes[i] < 1 {
   250  				continue
   251  			}
   252  
   253  			elem := elems[i]
   254  			elem.packet = bufs[i][offset : offset+sizes[i]]
   255  
   256  			// lookup peer
   257  			var peer *Peer
   258  			switch elem.packet[0] >> 4 {
   259  			case 4:
   260  				if len(elem.packet) < ipv4.HeaderLen {
   261  					continue
   262  				}
   263  				dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
   264  				peer = device.allowedips.Lookup(dst)
   265  
   266  			case 6:
   267  				if len(elem.packet) < ipv6.HeaderLen {
   268  					continue
   269  				}
   270  				dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
   271  				peer = device.allowedips.Lookup(dst)
   272  
   273  			default:
   274  				device.log.Verbosef("Received packet with unknown IP version")
   275  			}
   276  
   277  			if peer == nil {
   278  				continue
   279  			}
   280  			elemsForPeer, ok := elemsByPeer[peer]
   281  			if !ok {
   282  				elemsForPeer = device.GetOutboundElementsContainer()
   283  				elemsByPeer[peer] = elemsForPeer
   284  			}
   285  			elemsForPeer.elems = append(elemsForPeer.elems, elem)
   286  			elems[i] = device.NewOutboundElement()
   287  			bufs[i] = elems[i].buffer[:]
   288  		}
   289  
   290  		for peer, elemsForPeer := range elemsByPeer {
   291  			if peer.isRunning.Load() {
   292  				peer.StagePackets(elemsForPeer)
   293  				peer.SendStagedPackets()
   294  			} else {
   295  				for _, elem := range elemsForPeer.elems {
   296  					device.PutMessageBuffer(elem.buffer)
   297  					device.PutOutboundElement(elem)
   298  				}
   299  				device.PutOutboundElementsContainer(elemsForPeer)
   300  			}
   301  			delete(elemsByPeer, peer)
   302  		}
   303  
   304  		if readErr != nil {
   305  			if errors.Is(readErr, tun.ErrTooManySegments) {
   306  				// TODO: record stat for this
   307  				// This will happen if MSS is surprisingly small (< 576)
   308  				// coincident with reasonably high throughput.
   309  				device.log.Verbosef("Dropped some packets from multi-segment read: %v", readErr)
   310  				continue
   311  			}
   312  			if !device.isClosed() {
   313  				if !errors.Is(readErr, os.ErrClosed) {
   314  					device.log.Errorf("Failed to read packet from TUN device: %v", readErr)
   315  				}
   316  				go device.Close()
   317  			}
   318  			return
   319  		}
   320  	}
   321  }
   322  
   323  func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) {
   324  	for {
   325  		select {
   326  		case peer.queue.staged <- elems:
   327  			return
   328  		default:
   329  		}
   330  		select {
   331  		case tooOld := <-peer.queue.staged:
   332  			for _, elem := range tooOld.elems {
   333  				peer.device.PutMessageBuffer(elem.buffer)
   334  				peer.device.PutOutboundElement(elem)
   335  			}
   336  			peer.device.PutOutboundElementsContainer(tooOld)
   337  		default:
   338  		}
   339  	}
   340  }
   341  
   342  func (peer *Peer) SendStagedPackets() {
   343  top:
   344  	if len(peer.queue.staged) == 0 || !peer.device.isUp() {
   345  		return
   346  	}
   347  
   348  	keypair := peer.keypairs.Current()
   349  	if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime {
   350  		peer.SendHandshakeInitiation(false)
   351  		return
   352  	}
   353  
   354  	for {
   355  		var elemsContainerOOO *QueueOutboundElementsContainer
   356  		select {
   357  		case elemsContainer := <-peer.queue.staged:
   358  			i := 0
   359  			for _, elem := range elemsContainer.elems {
   360  				elem.peer = peer
   361  				elem.nonce = keypair.sendNonce.Add(1) - 1
   362  				if elem.nonce >= RejectAfterMessages {
   363  					keypair.sendNonce.Store(RejectAfterMessages)
   364  					if elemsContainerOOO == nil {
   365  						elemsContainerOOO = peer.device.GetOutboundElementsContainer()
   366  					}
   367  					elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem)
   368  					continue
   369  				} else {
   370  					elemsContainer.elems[i] = elem
   371  					i++
   372  				}
   373  
   374  				elem.keypair = keypair
   375  			}
   376  			elemsContainer.Lock()
   377  			elemsContainer.elems = elemsContainer.elems[:i]
   378  
   379  			if elemsContainerOOO != nil {
   380  				peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans
   381  			}
   382  
   383  			if len(elemsContainer.elems) == 0 {
   384  				peer.device.PutOutboundElementsContainer(elemsContainer)
   385  				goto top
   386  			}
   387  
   388  			// add to parallel and sequential queue
   389  			if peer.isRunning.Load() {
   390  				peer.queue.outbound.c <- elemsContainer
   391  				peer.device.queue.encryption.c <- elemsContainer
   392  			} else {
   393  				for _, elem := range elemsContainer.elems {
   394  					peer.device.PutMessageBuffer(elem.buffer)
   395  					peer.device.PutOutboundElement(elem)
   396  				}
   397  				peer.device.PutOutboundElementsContainer(elemsContainer)
   398  			}
   399  
   400  			if elemsContainerOOO != nil {
   401  				goto top
   402  			}
   403  		default:
   404  			return
   405  		}
   406  	}
   407  }
   408  
   409  func (peer *Peer) FlushStagedPackets() {
   410  	for {
   411  		select {
   412  		case elemsContainer := <-peer.queue.staged:
   413  			for _, elem := range elemsContainer.elems {
   414  				peer.device.PutMessageBuffer(elem.buffer)
   415  				peer.device.PutOutboundElement(elem)
   416  			}
   417  			peer.device.PutOutboundElementsContainer(elemsContainer)
   418  		default:
   419  			return
   420  		}
   421  	}
   422  }
   423  
   424  func calculatePaddingSize(packetSize, mtu int) int {
   425  	lastUnit := packetSize
   426  	if mtu == 0 {
   427  		return ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)) - lastUnit
   428  	}
   429  	if lastUnit > mtu {
   430  		lastUnit %= mtu
   431  	}
   432  	paddedSize := ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1))
   433  	if paddedSize > mtu {
   434  		paddedSize = mtu
   435  	}
   436  	return paddedSize - lastUnit
   437  }
   438  
   439  /* Encrypts the elements in the queue
   440   * and marks them for sequential consumption (by releasing the mutex)
   441   *
   442   * Obs. One instance per core
   443   */
   444  func (device *Device) RoutineEncryption(id int) {
   445  	var paddingZeros [PaddingMultiple]byte
   446  	var nonce [chacha20poly1305.NonceSize]byte
   447  
   448  	defer device.log.Verbosef("Routine: encryption worker %d - stopped", id)
   449  	device.log.Verbosef("Routine: encryption worker %d - started", id)
   450  
   451  	for elemsContainer := range device.queue.encryption.c {
   452  		for _, elem := range elemsContainer.elems {
   453  			// populate header fields
   454  			header := elem.buffer[:MessageTransportHeaderSize]
   455  
   456  			fieldType := header[0:4]
   457  			fieldReceiver := header[4:8]
   458  			fieldNonce := header[8:16]
   459  
   460  			binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
   461  			binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
   462  			binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
   463  
   464  			// pad content to multiple of 16
   465  			paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load()))
   466  			elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)
   467  
   468  			// encrypt content and release to consumer
   469  
   470  			binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
   471  			elem.packet = elem.keypair.send.Seal(
   472  				header,
   473  				nonce[:],
   474  				elem.packet,
   475  				nil,
   476  			)
   477  		}
   478  		elemsContainer.Unlock()
   479  	}
   480  }
   481  
   482  func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
   483  	device := peer.device
   484  	defer func() {
   485  		defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer)
   486  		peer.stopping.Done()
   487  	}()
   488  	device.log.Verbosef("%v - Routine: sequential sender - started", peer)
   489  
   490  	bufs := make([][]byte, 0, maxBatchSize)
   491  
   492  	for elemsContainer := range peer.queue.outbound.c {
   493  		bufs = bufs[:0]
   494  		if elemsContainer == nil {
   495  			return
   496  		}
   497  		if !peer.isRunning.Load() {
   498  			// peer has been stopped; return re-usable elems to the shared pool.
   499  			// This is an optimization only. It is possible for the peer to be stopped
   500  			// immediately after this check, in which case, elem will get processed.
   501  			// The timers and SendBuffers code are resilient to a few stragglers.
   502  			// TODO: rework peer shutdown order to ensure
   503  			// that we never accidentally keep timers alive longer than necessary.
   504  			elemsContainer.Lock()
   505  			for _, elem := range elemsContainer.elems {
   506  				device.PutMessageBuffer(elem.buffer)
   507  				device.PutOutboundElement(elem)
   508  			}
   509  			continue
   510  		}
   511  		dataSent := false
   512  		elemsContainer.Lock()
   513  		for _, elem := range elemsContainer.elems {
   514  			if len(elem.packet) != MessageKeepaliveSize {
   515  				dataSent = true
   516  			}
   517  			bufs = append(bufs, elem.packet)
   518  		}
   519  
   520  		peer.timersAnyAuthenticatedPacketTraversal()
   521  		peer.timersAnyAuthenticatedPacketSent()
   522  
   523  		err := peer.SendBuffers(bufs)
   524  		if dataSent {
   525  			peer.timersDataSent()
   526  		}
   527  		for _, elem := range elemsContainer.elems {
   528  			device.PutMessageBuffer(elem.buffer)
   529  			device.PutOutboundElement(elem)
   530  		}
   531  		device.PutOutboundElementsContainer(elemsContainer)
   532  		if err != nil {
   533  			var errGSO conn.ErrUDPGSODisabled
   534  			if errors.As(err, &errGSO) {
   535  				device.log.Verbosef(err.Error())
   536  				err = errGSO.RetryErr
   537  			}
   538  		}
   539  		if err != nil {
   540  			device.log.Errorf("%v - Failed to send data packets: %v", peer, err)
   541  			continue
   542  		}
   543  
   544  		peer.keepKeyFreshSending()
   545  	}
   546  }