github.com/bepass-org/wireguard-go@v1.0.4-rc2.0.20240304192354-ebce6572bc24/device/send.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"bytes"
    10  	"crypto/rand"
    11  	"encoding/binary"
    12  	"errors"
    13  	"math/big"
    14  	"net"
    15  	"os"
    16  	"sync"
    17  	"time"
    18  
    19  	"github.com/bepass-org/wireguard-go/conn"
    20  	"github.com/bepass-org/wireguard-go/tun"
    21  	"golang.org/x/crypto/chacha20poly1305"
    22  	"golang.org/x/net/ipv4"
    23  	"golang.org/x/net/ipv6"
    24  )
    25  
    26  /* Outbound flow
    27   *
    28   * 1. TUN queue
    29   * 2. Routing (sequential)
    30   * 3. Nonce assignment (sequential)
    31   * 4. Encryption (parallel)
    32   * 5. Transmission (sequential)
    33   *
    34   * The functions in this file occur (roughly) in the order in
    35   * which the packets are processed.
    36   *
    37   * Locking, Producers and Consumers
    38   *
    39   * The order of packets (per peer) must be maintained,
 * but encryption of packets happens out-of-order:
    41   *
    42   * The sequential consumers will attempt to take the lock,
    43   * workers release lock when they have completed work (encryption) on the packet.
    44   *
    45   * If the element is inserted into the "encryption queue",
    46   * the content is preceded by enough "junk" to contain the transport header
    47   * (to allow the construction of transport messages in-place)
    48   */
    49  
// QueueOutboundElement carries a single outbound packet through the send
// pipeline (staging, nonce assignment, encryption, transmission). Elements
// and their message buffers are obtained from and returned to Device pools
// (see NewOutboundElement and the Put* calls throughout this file).
type QueueOutboundElement struct {
	// buffer is the pooled backing storage; the TUN reader writes plaintext
	// at offset MessageTransportHeaderSize so the transport header can later
	// be written in front of it in place.
	buffer  *[MaxMessageSize]byte // slice holding the packet data
	packet  []byte                // slice of "buffer" (always!)
	nonce   uint64                // nonce for encryption
	keypair *Keypair              // keypair for encryption
	peer    *Peer                 // related peer
}
    57  
// QueueOutboundElementsContainer groups outbound elements that travel through
// the pipeline together as one batch. SendStagedPackets locks the container
// before handing the same pointer to both the encryption queue and the peer's
// outbound queue; the encryption worker unlocks it after sealing every
// element, so the sequential sender's Lock waits until encryption is done.
type QueueOutboundElementsContainer struct {
	sync.Mutex
	elems []*QueueOutboundElement
}
    62  
    63  func (device *Device) NewOutboundElement() *QueueOutboundElement {
    64  	elem := device.GetOutboundElement()
    65  	elem.buffer = device.GetMessageBuffer()
    66  	elem.nonce = 0
    67  	// keypair and peer were cleared (if necessary) by clearPointers.
    68  	return elem
    69  }
    70  
    71  // clearPointers clears elem fields that contain pointers.
    72  // This makes the garbage collector's life easier and
    73  // avoids accidentally keeping other objects around unnecessarily.
    74  // It also reduces the possible collateral damage from use-after-free bugs.
    75  func (elem *QueueOutboundElement) clearPointers() {
    76  	elem.buffer = nil
    77  	elem.packet = nil
    78  	elem.keypair = nil
    79  	elem.peer = nil
    80  }
    81  
    82  func randomInt(min, max int) int {
    83  	nBig, err := rand.Int(rand.Reader, big.NewInt(int64(max-min+1)))
    84  	if err != nil {
    85  		panic(err)
    86  	}
    87  	return int(nBig.Int64()) + min
    88  }
    89  
    90  func (peer *Peer) sendRandomPackets() {
    91  	// Generate a random number of packets between 5 and 10
    92  	numPackets := randomInt(8, 15)
    93  	for i := 0; i < numPackets; i++ {
    94  		// Generate a random packet size between 10 and 40 bytes
    95  		packetSize := randomInt(40, 100)
    96  		randomPacket := make([]byte, packetSize)
    97  		_, err := rand.Read(randomPacket)
    98  		if err != nil {
    99  			return
   100  		}
   101  
   102  		// Send the random packet
   103  		err = peer.SendBuffers([][]byte{randomPacket})
   104  		if err != nil {
   105  			return
   106  		}
   107  
   108  		if i < numPackets-1 && peer.isRunning.Load() && !peer.device.isClosed() {
   109  			select {
   110  			case <-peer.stopCh:
   111  			// Wait for a random duration between 20 and 250 milliseconds
   112  			case <-time.After(time.Duration(randomInt(20, 250)) * time.Millisecond):
   113  			}
   114  		}
   115  	}
   116  }
   117  
   118  /* Queues a keepalive if no packets are queued for peer
   119   */
   120  func (peer *Peer) SendKeepalive() {
   121  	if len(peer.queue.staged) == 0 && peer.isRunning.Load() {
   122  		// Send some random packets on every keepalive
   123  		if peer.trick {
   124  			peer.sendRandomPackets()
   125  		}
   126  
   127  		elem := peer.device.NewOutboundElement()
   128  		elemsContainer := peer.device.GetOutboundElementsContainer()
   129  		elemsContainer.elems = append(elemsContainer.elems, elem)
   130  		select {
   131  		case peer.queue.staged <- elemsContainer:
   132  			peer.device.log.Verbosef("%v - Sending keepalive packet", peer)
   133  		default:
   134  			peer.device.PutMessageBuffer(elem.buffer)
   135  			peer.device.PutOutboundElement(elem)
   136  			peer.device.PutOutboundElementsContainer(elemsContainer)
   137  		}
   138  	}
   139  	peer.SendStagedPackets()
   140  }
   141  
   142  func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
   143  	if !isRetry {
   144  		peer.timers.handshakeAttempts.Store(0)
   145  	}
   146  
   147  	peer.handshake.mutex.RLock()
   148  	if time.Since(peer.handshake.lastSentHandshake) < RekeyTimeout {
   149  		peer.handshake.mutex.RUnlock()
   150  		return nil
   151  	}
   152  	peer.handshake.mutex.RUnlock()
   153  
   154  	peer.handshake.mutex.Lock()
   155  	if time.Since(peer.handshake.lastSentHandshake) < RekeyTimeout {
   156  		peer.handshake.mutex.Unlock()
   157  		return nil
   158  	}
   159  
   160  	// send some random packets on handshake
   161  	if peer.trick {
   162  		peer.sendRandomPackets()
   163  	}
   164  
   165  	peer.handshake.lastSentHandshake = time.Now()
   166  	peer.handshake.mutex.Unlock()
   167  
   168  	peer.device.log.Verbosef("%v - Sending handshake initiation", peer)
   169  
   170  	msg, err := peer.device.CreateMessageInitiation(peer)
   171  	if err != nil {
   172  		peer.device.log.Errorf("%v - Failed to create initiation message: %v", peer, err)
   173  		return err
   174  	}
   175  
   176  	var buf [MessageInitiationSize]byte
   177  	writer := bytes.NewBuffer(buf[:0])
   178  	binary.Write(writer, binary.LittleEndian, msg)
   179  	packet := writer.Bytes()
   180  	peer.cookieGenerator.AddMacs(packet)
   181  
   182  	peer.timersAnyAuthenticatedPacketTraversal()
   183  	peer.timersAnyAuthenticatedPacketSent()
   184  
   185  	err = peer.SendBuffers([][]byte{packet})
   186  	if err != nil {
   187  		peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
   188  	}
   189  	peer.timersHandshakeInitiated()
   190  
   191  	return err
   192  }
   193  
   194  func (peer *Peer) SendHandshakeResponse() error {
   195  	peer.handshake.mutex.Lock()
   196  	peer.handshake.lastSentHandshake = time.Now()
   197  	peer.handshake.mutex.Unlock()
   198  
   199  	peer.device.log.Verbosef("%v - Sending handshake response", peer)
   200  
   201  	response, err := peer.device.CreateMessageResponse(peer)
   202  	if err != nil {
   203  		peer.device.log.Errorf("%v - Failed to create response message: %v", peer, err)
   204  		return err
   205  	}
   206  
   207  	var buf [MessageResponseSize]byte
   208  	writer := bytes.NewBuffer(buf[:0])
   209  	binary.Write(writer, binary.LittleEndian, response)
   210  	packet := writer.Bytes()
   211  	peer.cookieGenerator.AddMacs(packet)
   212  
   213  	err = peer.BeginSymmetricSession()
   214  	if err != nil {
   215  		peer.device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
   216  		return err
   217  	}
   218  
   219  	peer.timersSessionDerived()
   220  	peer.timersAnyAuthenticatedPacketTraversal()
   221  	peer.timersAnyAuthenticatedPacketSent()
   222  
   223  	// TODO: allocation could be avoided
   224  	err = peer.SendBuffers([][]byte{packet})
   225  	if err != nil {
   226  		peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
   227  	}
   228  	return err
   229  }
   230  
   231  func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) error {
   232  	device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString())
   233  
   234  	sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8])
   235  	reply, err := device.cookieChecker.CreateReply(initiatingElem.packet, sender, initiatingElem.endpoint.DstToBytes())
   236  	if err != nil {
   237  		device.log.Errorf("Failed to create cookie reply: %v", err)
   238  		return err
   239  	}
   240  
   241  	var buf [MessageCookieReplySize]byte
   242  	writer := bytes.NewBuffer(buf[:0])
   243  	binary.Write(writer, binary.LittleEndian, reply)
   244  	// TODO: allocation could be avoided
   245  	device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint)
   246  	return nil
   247  }
   248  
   249  func (peer *Peer) keepKeyFreshSending() {
   250  	keypair := peer.keypairs.Current()
   251  	if keypair == nil {
   252  		return
   253  	}
   254  	nonce := keypair.sendNonce.Load()
   255  	if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) {
   256  		peer.SendHandshakeInitiation(false)
   257  	}
   258  }
   259  
// RoutineReadFromTUN is the TUN reader loop. It reads batches of packets from
// the TUN device into pooled outbound elements and looks up each packet's
// destination peer in allowedips. It then groups the packets by peer and
// hands each group to the peer (StagePackets followed by SendStagedPackets).
//
// It keeps running across tun.ErrTooManySegments (those packets are dropped).
// Any other read error ends the loop. In that case, unless the device is
// already closed, the error is logged (if it is not os.ErrClosed) and the
// device is closed asynchronously.
//
// On exit it signals device.state.stopping and device.queue.encryption.wg, and
// returns the elements this routine still owns to the pools.
func (device *Device) RoutineReadFromTUN() {
	defer func() {
		device.log.Verbosef("Routine: TUN reader - stopped")
		device.state.stopping.Done()
		device.queue.encryption.wg.Done()
	}()

	device.log.Verbosef("Routine: TUN reader - started")

	var (
		batchSize   = device.BatchSize()
		readErr     error
		elems       = make([]*QueueOutboundElement, batchSize)
		bufs        = make([][]byte, batchSize)
		elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize)
		count       = 0
		sizes       = make([]int, batchSize)
		// Packets are read at this offset so the transport header can later
		// be written in front of them in place.
		offset      = MessageTransportHeaderSize
	)

	// Each read slot owns one pooled element; bufs[i] views its whole buffer.
	for i := range elems {
		elems[i] = device.NewOutboundElement()
		bufs[i] = elems[i].buffer[:]
	}

	// Recycle the elements still owned by the read slots (elements handed to
	// a peer have already been replaced in elems).
	defer func() {
		for _, elem := range elems {
			if elem != nil {
				device.PutMessageBuffer(elem.buffer)
				device.PutOutboundElement(elem)
			}
		}
	}()

	for {
		// read packets
		count, readErr = device.tun.device.Read(bufs, sizes, offset)
		for i := 0; i < count; i++ {
			// Skipped packets (empty, truncated header, unknown version, no
			// peer) leave elems[i] in place; its buffer is reused next read.
			if sizes[i] < 1 {
				continue
			}

			elem := elems[i]
			elem.packet = bufs[i][offset : offset+sizes[i]]

			// lookup peer
			var peer *Peer
			switch elem.packet[0] >> 4 {
			case 4:
				if len(elem.packet) < ipv4.HeaderLen {
					continue
				}
				dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
				peer = device.allowedips.Lookup(dst)

			case 6:
				if len(elem.packet) < ipv6.HeaderLen {
					continue
				}
				dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
				peer = device.allowedips.Lookup(dst)

			default:
				device.log.Verbosef("Received packet with unknown IP version")
			}

			if peer == nil {
				continue
			}
			elemsForPeer, ok := elemsByPeer[peer]
			if !ok {
				elemsForPeer = device.GetOutboundElementsContainer()
				elemsByPeer[peer] = elemsForPeer
			}
			elemsForPeer.elems = append(elemsForPeer.elems, elem)
			// The element now belongs to the peer's batch; refill this slot.
			elems[i] = device.NewOutboundElement()
			bufs[i] = elems[i].buffer[:]
		}

		// Hand each peer its batch (or recycle it if the peer is stopped).
		// Deleting entries while ranging is permitted in Go and leaves the map
		// empty for the next read.
		for peer, elemsForPeer := range elemsByPeer {
			if peer.isRunning.Load() {
				peer.StagePackets(elemsForPeer)
				peer.SendStagedPackets()
			} else {
				for _, elem := range elemsForPeer.elems {
					device.PutMessageBuffer(elem.buffer)
					device.PutOutboundElement(elem)
				}
				device.PutOutboundElementsContainer(elemsForPeer)
			}
			delete(elemsByPeer, peer)
		}

		// Read errors are handled only after the packets that were read in the
		// same call have been dispatched.
		if readErr != nil {
			if errors.Is(readErr, tun.ErrTooManySegments) {
				// TODO: record stat for this
				// This will happen if MSS is surprisingly small (< 576)
				// coincident with reasonably high throughput.
				device.log.Verbosef("Dropped some packets from multi-segment read: %v", readErr)
				continue
			}
			if !device.isClosed() {
				if !errors.Is(readErr, os.ErrClosed) {
					device.log.Errorf("Failed to read packet from TUN device: %v", readErr)
				}
				go device.Close()
			}
			return
		}
	}
}
   371  
   372  func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) {
   373  	for {
   374  		select {
   375  		case peer.queue.staged <- elems:
   376  			return
   377  		default:
   378  		}
   379  		select {
   380  		case tooOld := <-peer.queue.staged:
   381  			for _, elem := range tooOld.elems {
   382  				peer.device.PutMessageBuffer(elem.buffer)
   383  				peer.device.PutOutboundElement(elem)
   384  			}
   385  			peer.device.PutOutboundElementsContainer(tooOld)
   386  		default:
   387  		}
   388  	}
   389  }
   390  
// SendStagedPackets moves staged batches into the encryption queue and the
// peer's sequential outbound queue, assigning each element the next send
// nonce of the current keypair.
//
// If there is no usable keypair (none at all, nonce limit reached, or older
// than RejectAfterTime), it requests a handshake initiation instead and leaves
// the packets staged. It returns when the staging queue is empty, the device
// is not up, or a handshake was requested.
func (peer *Peer) SendStagedPackets() {
top:
	if len(peer.queue.staged) == 0 || !peer.device.isUp() {
		return
	}

	keypair := peer.keypairs.Current()
	if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime {
		peer.SendHandshakeInitiation(false)
		return
	}

	for {
		// Collects elements that could not get a valid nonce from this keypair.
		var elemsContainerOOO *QueueOutboundElementsContainer
		select {
		case elemsContainer := <-peer.queue.staged:
			i := 0
			for _, elem := range elemsContainer.elems {
				elem.peer = peer
				elem.nonce = keypair.sendNonce.Add(1) - 1
				if elem.nonce >= RejectAfterMessages {
					// Keypair exhausted: pin the counter at the limit so the
					// keypair check at top forces a handshake, and set this
					// element aside to be re-staged.
					keypair.sendNonce.Store(RejectAfterMessages)
					if elemsContainerOOO == nil {
						elemsContainerOOO = peer.device.GetOutboundElementsContainer()
					}
					elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem)
					continue
				} else {
					// Compact the kept elements to the front of the slice.
					elemsContainer.elems[i] = elem
					i++
				}

				elem.keypair = keypair
			}
			// Lock before publishing: RoutineEncryption unlocks the container
			// once all elements are sealed, and RoutineSequentialSender blocks
			// on Lock until then, which preserves send order.
			elemsContainer.Lock()
			elemsContainer.elems = elemsContainer.elems[:i]

			if elemsContainerOOO != nil {
				peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans
			}

			if len(elemsContainer.elems) == 0 {
				peer.device.PutOutboundElementsContainer(elemsContainer)
				goto top
			}

			// add to parallel and sequential queue
			if peer.isRunning.Load() {
				peer.queue.outbound.c <- elemsContainer
				peer.device.queue.encryption.c <- elemsContainer
			} else {
				// NOTE(review): the container goes back to the pool still
				// locked; presumably GetOutboundElementsContainer resets the
				// mutex — confirm.
				for _, elem := range elemsContainer.elems {
					peer.device.PutMessageBuffer(elem.buffer)
					peer.device.PutOutboundElement(elem)
				}
				peer.device.PutOutboundElementsContainer(elemsContainer)
			}

			// Some elements were re-staged: restart so the exhausted keypair
			// is noticed and a handshake is requested.
			if elemsContainerOOO != nil {
				goto top
			}
		default:
			return
		}
	}
}
   457  
   458  func (peer *Peer) FlushStagedPackets() {
   459  	for {
   460  		select {
   461  		case elemsContainer := <-peer.queue.staged:
   462  			for _, elem := range elemsContainer.elems {
   463  				peer.device.PutMessageBuffer(elem.buffer)
   464  				peer.device.PutOutboundElement(elem)
   465  			}
   466  			peer.device.PutOutboundElementsContainer(elemsContainer)
   467  		default:
   468  			return
   469  		}
   470  	}
   471  }
   472  
   473  func calculatePaddingSize(packetSize, mtu int) int {
   474  	lastUnit := packetSize
   475  	if mtu == 0 {
   476  		return ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)) - lastUnit
   477  	}
   478  	if lastUnit > mtu {
   479  		lastUnit %= mtu
   480  	}
   481  	paddedSize := ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1))
   482  	if paddedSize > mtu {
   483  		paddedSize = mtu
   484  	}
   485  	return paddedSize - lastUnit
   486  }
   487  
   488  /* Encrypts the elements in the queue
   489   * and marks them for sequential consumption (by releasing the mutex)
   490   *
   491   * Obs. One instance per core
   492   */
   493  func (device *Device) RoutineEncryption(id int) {
   494  	var paddingZeros [PaddingMultiple]byte
   495  	var nonce [chacha20poly1305.NonceSize]byte
   496  
   497  	defer device.log.Verbosef("Routine: encryption worker %d - stopped", id)
   498  	device.log.Verbosef("Routine: encryption worker %d - started", id)
   499  
   500  	for elemsContainer := range device.queue.encryption.c {
   501  		for _, elem := range elemsContainer.elems {
   502  			// populate header fields
   503  			header := elem.buffer[:MessageTransportHeaderSize]
   504  
   505  			fieldType := header[0:4]
   506  			fieldReceiver := header[4:8]
   507  			fieldNonce := header[8:16]
   508  
   509  			binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
   510  			binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
   511  			binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
   512  
   513  			// pad content to multiple of 16
   514  			paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load()))
   515  			elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)
   516  
   517  			// encrypt content and release to consumer
   518  
   519  			binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
   520  			elem.packet = elem.keypair.send.Seal(
   521  				header,
   522  				nonce[:],
   523  				elem.packet,
   524  				nil,
   525  			)
   526  		}
   527  		elemsContainer.Unlock()
   528  	}
   529  }
   530  
   531  func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
   532  	device := peer.device
   533  	defer func() {
   534  		defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer)
   535  		peer.stopping.Done()
   536  	}()
   537  	device.log.Verbosef("%v - Routine: sequential sender - started", peer)
   538  
   539  	bufs := make([][]byte, 0, maxBatchSize)
   540  
   541  	for elemsContainer := range peer.queue.outbound.c {
   542  		bufs = bufs[:0]
   543  		if elemsContainer == nil {
   544  			return
   545  		}
   546  		if !peer.isRunning.Load() {
   547  			// peer has been stopped; return re-usable elems to the shared pool.
   548  			// This is an optimization only. It is possible for the peer to be stopped
   549  			// immediately after this check, in which case, elem will get processed.
   550  			// The timers and SendBuffers code are resilient to a few stragglers.
   551  			// TODO: rework peer shutdown order to ensure
   552  			// that we never accidentally keep timers alive longer than necessary.
   553  			elemsContainer.Lock()
   554  			for _, elem := range elemsContainer.elems {
   555  				device.PutMessageBuffer(elem.buffer)
   556  				device.PutOutboundElement(elem)
   557  			}
   558  			continue
   559  		}
   560  		dataSent := false
   561  		elemsContainer.Lock()
   562  		for _, elem := range elemsContainer.elems {
   563  			if len(elem.packet) != MessageKeepaliveSize {
   564  				dataSent = true
   565  			}
   566  			bufs = append(bufs, elem.packet)
   567  		}
   568  
   569  		peer.timersAnyAuthenticatedPacketTraversal()
   570  		peer.timersAnyAuthenticatedPacketSent()
   571  
   572  		err := peer.SendBuffers(bufs)
   573  		if dataSent {
   574  			peer.timersDataSent()
   575  		}
   576  		for _, elem := range elemsContainer.elems {
   577  			device.PutMessageBuffer(elem.buffer)
   578  			device.PutOutboundElement(elem)
   579  		}
   580  		device.PutOutboundElementsContainer(elemsContainer)
   581  		if err != nil {
   582  			var errGSO conn.ErrUDPGSODisabled
   583  			if errors.As(err, &errGSO) {
   584  				device.log.Verbosef(err.Error())
   585  				err = errGSO.RetryErr
   586  			}
   587  		}
   588  		if err != nil {
   589  			device.log.Errorf("%v - Failed to send data packets: %v", peer, err)
   590  			continue
   591  		}
   592  
   593  		peer.keepKeyFreshSending()
   594  	}
   595  }