github.com/GFW-knocker/wireguard@v1.0.1/device/send.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"bytes"
    10  	"crypto/rand"
    11  	"encoding/binary"
    12  	"errors"
    13  	"fmt"
    14  	"math/big"
    15  	"net"
    16  	"os"
    17  	"sync"
    18  	"time"
    19  
    20  	"github.com/GFW-knocker/wireguard/conn"
    21  	"github.com/GFW-knocker/wireguard/tun"
    22  	"golang.org/x/crypto/chacha20poly1305"
    23  	"golang.org/x/net/ipv4"
    24  	"golang.org/x/net/ipv6"
    25  )
    26  
    27  /* Outbound flow
    28   *
    29   * 1. TUN queue
    30   * 2. Routing (sequential)
    31   * 3. Nonce assignment (sequential)
    32   * 4. Encryption (parallel)
    33   * 5. Transmission (sequential)
    34   *
    35   * The functions in this file occur (roughly) in the order in
    36   * which the packets are processed.
    37   *
    38   * Locking, Producers and Consumers
    39   *
    40   * The order of packets (per peer) must be maintained,
 * but encryption of packets happens out-of-order:
 *
 * The sequential consumers will attempt to take the lock;
 * workers release the lock when they have completed work (encryption) on the packet.
    45   *
    46   * If the element is inserted into the "encryption queue",
    47   * the content is preceded by enough "junk" to contain the transport header
    48   * (to allow the construction of transport messages in-place)
    49   */
    50  
// QueueOutboundElement carries one outbound packet through the send pipeline:
// staging (StagePackets), nonce assignment (SendStagedPackets), encryption
// (RoutineEncryption) and transmission (RoutineSequentialSender).
type QueueOutboundElement struct {
	buffer  *[MaxMessageSize]byte // pooled backing array; packet data follows transport-header headroom
	packet  []byte                // slice of "buffer" (always!)
	nonce   uint64                // nonce for encryption
	keypair *Keypair              // keypair for encryption
	peer    *Peer                 // related peer
}
    58  
// QueueOutboundElementsContainer is a batch of outbound elements for a single
// peer. The embedded mutex orders encryption before transmission:
// SendStagedPackets locks it before queueing the batch, RoutineEncryption
// unlocks it once every element is sealed, and RoutineSequentialSender locks
// it again (i.e. waits for encryption) before sending.
type QueueOutboundElementsContainer struct {
	sync.Mutex
	elems []*QueueOutboundElement
}
    63  
    64  func (device *Device) NewOutboundElement() *QueueOutboundElement {
    65  	elem := device.GetOutboundElement()
    66  	elem.buffer = device.GetMessageBuffer()
    67  	elem.nonce = 0
    68  	// keypair and peer were cleared (if necessary) by clearPointers.
    69  	return elem
    70  }
    71  
    72  // clearPointers clears elem fields that contain pointers.
    73  // This makes the garbage collector's life easier and
    74  // avoids accidentally keeping other objects around unnecessarily.
    75  // It also reduces the possible collateral damage from use-after-free bugs.
    76  func (elem *QueueOutboundElement) clearPointers() {
    77  	elem.buffer = nil
    78  	elem.packet = nil
    79  	elem.keypair = nil
    80  	elem.peer = nil
    81  }
    82  
    83  /* Queues a keepalive if no packets are queued for peer
    84   */
    85  func (peer *Peer) SendKeepalive() {
    86  	if len(peer.queue.staged) == 0 && peer.isRunning.Load() {
    87  		elem := peer.device.NewOutboundElement()
    88  		elemsContainer := peer.device.GetOutboundElementsContainer()
    89  		elemsContainer.elems = append(elemsContainer.elems, elem)
    90  		select {
    91  		case peer.queue.staged <- elemsContainer:
    92  			peer.device.log.Verbosef("%v - Sending keepalive packet", peer)
    93  		default:
    94  			peer.device.PutMessageBuffer(elem.buffer)
    95  			peer.device.PutOutboundElement(elem)
    96  			peer.device.PutOutboundElementsContainer(elemsContainer)
    97  		}
    98  	}
    99  	peer.SendStagedPackets()
   100  }
   101  
   102  func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
   103  	if !isRetry {
   104  		peer.timers.handshakeAttempts.Store(0)
   105  	}
   106  
   107  	peer.handshake.mutex.RLock()
   108  	if time.Since(peer.handshake.lastSentHandshake) < RekeyTimeout {
   109  		peer.handshake.mutex.RUnlock()
   110  		return nil
   111  	}
   112  	peer.handshake.mutex.RUnlock()
   113  
   114  	peer.handshake.mutex.Lock()
   115  	if time.Since(peer.handshake.lastSentHandshake) < RekeyTimeout {
   116  		peer.handshake.mutex.Unlock()
   117  		return nil
   118  	}
   119  	peer.handshake.lastSentHandshake = time.Now()
   120  	peer.handshake.mutex.Unlock()
   121  
   122  	peer.device.log.Verbosef("%v - Sending handshake initiation", peer)
   123  
   124  	msg, err := peer.device.CreateMessageInitiation(peer)
   125  	if err != nil {
   126  		peer.device.log.Errorf("%v - Failed to create initiation message: %v", peer, err)
   127  		return err
   128  	}
   129  
   130  	var buf [MessageInitiationSize]byte
   131  	writer := bytes.NewBuffer(buf[:0])
   132  	binary.Write(writer, binary.LittleEndian, msg)
   133  	packet := writer.Bytes()
   134  	peer.cookieGenerator.AddMacs(packet)
   135  
   136  	peer.timersAnyAuthenticatedPacketTraversal()
   137  	peer.timersAnyAuthenticatedPacketSent()
   138  
   139  	// -------------- GFW-knocker inspired by uoosef bepass-org ----------------
   140  	// send some random packet before handshake
   141  	//fmt.Println("wireguard handshake")
   142  
   143  	numPackets := randomInt(10, 20)
   144  	for i := 0; i < numPackets; i++ {
   145  		// Generate a random packet size between 10 and 40 bytes
   146  		packetSize := randomInt(40, 100)
   147  		randomPacket := make([]byte, packetSize)
   148  		_, err := rand.Read(randomPacket)
   149  		if err != nil {
   150  			return fmt.Errorf("error generating random packet: %v", err)
   151  		}
   152  
   153  		// Send the random packet
   154  		err = peer.SendBuffers([][]byte{randomPacket})
   155  		if err != nil {
   156  			return fmt.Errorf("error sending random packet: %v", err)
   157  		}
   158  
   159  		// Wait for a random duration between 10 and 50 milliseconds
   160  		time.Sleep(time.Duration(randomInt(10, 50)) * time.Millisecond)
   161  	}
   162  	// ---------------------------------------------------------------
   163  
   164  	err = peer.SendBuffers([][]byte{packet})
   165  	if err != nil {
   166  		peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
   167  	}
   168  	peer.timersHandshakeInitiated()
   169  
   170  	return err
   171  }
   172  
   173  // ---------------- by uoosef bepass-org ------------------------------------
   174  func randomInt(min, max int) int {
   175  	nBig, err := rand.Int(rand.Reader, big.NewInt(int64(max-min+1)))
   176  	if err != nil {
   177  		panic(err)
   178  	}
   179  	return int(nBig.Int64()) + min
   180  }
   181  
   182  // --------------------------------------------------------------------------
   183  
   184  func (peer *Peer) SendHandshakeResponse() error {
   185  	peer.handshake.mutex.Lock()
   186  	peer.handshake.lastSentHandshake = time.Now()
   187  	peer.handshake.mutex.Unlock()
   188  
   189  	peer.device.log.Verbosef("%v - Sending handshake response", peer)
   190  
   191  	response, err := peer.device.CreateMessageResponse(peer)
   192  	if err != nil {
   193  		peer.device.log.Errorf("%v - Failed to create response message: %v", peer, err)
   194  		return err
   195  	}
   196  
   197  	var buf [MessageResponseSize]byte
   198  	writer := bytes.NewBuffer(buf[:0])
   199  	binary.Write(writer, binary.LittleEndian, response)
   200  	packet := writer.Bytes()
   201  	peer.cookieGenerator.AddMacs(packet)
   202  
   203  	err = peer.BeginSymmetricSession()
   204  	if err != nil {
   205  		peer.device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
   206  		return err
   207  	}
   208  
   209  	peer.timersSessionDerived()
   210  	peer.timersAnyAuthenticatedPacketTraversal()
   211  	peer.timersAnyAuthenticatedPacketSent()
   212  
   213  	// TODO: allocation could be avoided
   214  	err = peer.SendBuffers([][]byte{packet})
   215  	if err != nil {
   216  		peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
   217  	}
   218  	return err
   219  }
   220  
   221  func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) error {
   222  	device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString())
   223  
   224  	sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8])
   225  	reply, err := device.cookieChecker.CreateReply(initiatingElem.packet, sender, initiatingElem.endpoint.DstToBytes())
   226  	if err != nil {
   227  		device.log.Errorf("Failed to create cookie reply: %v", err)
   228  		return err
   229  	}
   230  
   231  	var buf [MessageCookieReplySize]byte
   232  	writer := bytes.NewBuffer(buf[:0])
   233  	binary.Write(writer, binary.LittleEndian, reply)
   234  	// TODO: allocation could be avoided
   235  	device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint)
   236  	return nil
   237  }
   238  
   239  func (peer *Peer) keepKeyFreshSending() {
   240  	keypair := peer.keypairs.Current()
   241  	if keypair == nil {
   242  		return
   243  	}
   244  	nonce := keypair.sendNonce.Load()
   245  	if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) {
   246  		peer.SendHandshakeInitiation(false)
   247  	}
   248  }
   249  
// RoutineReadFromTUN is the TUN reader loop. It reads batches of packets from
// the TUN device into pooled outbound elements, leaving
// MessageTransportHeaderSize bytes of headroom in front of each packet so the
// transport header can later be written in place. Each packet is routed to a
// peer by looking up its IPv4/IPv6 destination in device.allowedips. Packets
// are grouped per peer, and each group is staged and sent (or recycled if the
// peer is not running). On any read error other than tun.ErrTooManySegments
// the routine returns, first triggering device.Close() if the device is not
// already closed.
func (device *Device) RoutineReadFromTUN() {
	defer func() {
		device.log.Verbosef("Routine: TUN reader - stopped")
		device.state.stopping.Done()
		device.queue.encryption.wg.Done()
	}()

	device.log.Verbosef("Routine: TUN reader - started")

	var (
		batchSize   = device.BatchSize()
		readErr     error
		elems       = make([]*QueueOutboundElement, batchSize)
		bufs        = make([][]byte, batchSize)
		elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize)
		count       = 0
		sizes       = make([]int, batchSize)
		offset      = MessageTransportHeaderSize
	)

	// One pooled element per batch slot; bufs[i] aliases elems[i].buffer.
	for i := range elems {
		elems[i] = device.NewOutboundElement()
		bufs[i] = elems[i].buffer[:]
	}

	// On exit, return the elements this routine still owns to the pools.
	defer func() {
		for _, elem := range elems {
			if elem != nil {
				device.PutMessageBuffer(elem.buffer)
				device.PutOutboundElement(elem)
			}
		}
	}()

	for {
		// read packets
		count, readErr = device.tun.device.Read(bufs, sizes, offset)
		for i := 0; i < count; i++ {
			if sizes[i] < 1 {
				continue
			}

			elem := elems[i]
			elem.packet = bufs[i][offset : offset+sizes[i]]

			// lookup peer
			var peer *Peer
			switch elem.packet[0] >> 4 {
			case 4:
				if len(elem.packet) < ipv4.HeaderLen {
					continue
				}
				dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
				peer = device.allowedips.Lookup(dst)

			case 6:
				if len(elem.packet) < ipv6.HeaderLen {
					continue
				}
				dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
				peer = device.allowedips.Lookup(dst)

			default:
				device.log.Verbosef("Received packet with unknown IP version")
			}

			// Unknown IP version or no allowed-IPs match: drop the packet by
			// leaving elems[i] in place, so its buffer is reused on the next read.
			if peer == nil {
				continue
			}
			elemsForPeer, ok := elemsByPeer[peer]
			if !ok {
				elemsForPeer = device.GetOutboundElementsContainer()
				elemsByPeer[peer] = elemsForPeer
			}
			elemsForPeer.elems = append(elemsForPeer.elems, elem)
			// elem now belongs to the per-peer container; refill slot i.
			elems[i] = device.NewOutboundElement()
			bufs[i] = elems[i].buffer[:]
		}

		// Hand each peer's batch off (or recycle it if the peer is stopped).
		// Entries are deleted so the map is empty before the next read.
		for peer, elemsForPeer := range elemsByPeer {
			if peer.isRunning.Load() {
				peer.StagePackets(elemsForPeer)
				peer.SendStagedPackets()
			} else {
				for _, elem := range elemsForPeer.elems {
					device.PutMessageBuffer(elem.buffer)
					device.PutOutboundElement(elem)
				}
				device.PutOutboundElementsContainer(elemsForPeer)
			}
			delete(elemsByPeer, peer)
		}

		if readErr != nil {
			if errors.Is(readErr, tun.ErrTooManySegments) {
				// TODO: record stat for this
				// This will happen if MSS is surprisingly small (< 576)
				// coincident with reasonably high throughput.
				device.log.Verbosef("Dropped some packets from multi-segment read: %v", readErr)
				continue
			}
			if !device.isClosed() {
				if !errors.Is(readErr, os.ErrClosed) {
					device.log.Errorf("Failed to read packet from TUN device: %v", readErr)
				}
				// Close asynchronously: presumably Close waits on this
				// routine (state.stopping), so calling it inline could deadlock.
				go device.Close()
			}
			return
		}
	}
}
   361  
   362  func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) {
   363  	for {
   364  		select {
   365  		case peer.queue.staged <- elems:
   366  			return
   367  		default:
   368  		}
   369  		select {
   370  		case tooOld := <-peer.queue.staged:
   371  			for _, elem := range tooOld.elems {
   372  				peer.device.PutMessageBuffer(elem.buffer)
   373  				peer.device.PutOutboundElement(elem)
   374  			}
   375  			peer.device.PutOutboundElementsContainer(tooOld)
   376  		default:
   377  		}
   378  	}
   379  }
   380  
// SendStagedPackets drains peer.queue.staged. It assigns every element a nonce
// from the current keypair and hands each non-empty container to both the
// peer's sequential-send queue and the device's encryption queue.
//
// It returns immediately if nothing is staged or the device is not up. If
// there is no usable current keypair (missing, at/over RejectAfterMessages, or
// at/over RejectAfterTime), it requests a handshake and leaves the packets
// staged.
func (peer *Peer) SendStagedPackets() {
top:
	if len(peer.queue.staged) == 0 || !peer.device.isUp() {
		return
	}

	keypair := peer.keypairs.Current()
	if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime {
		peer.SendHandshakeInitiation(false)
		return
	}

	for {
		// Elements whose nonce came out at/over RejectAfterMessages; they are
		// re-staged instead of being sent with this keypair.
		var elemsContainerOOO *QueueOutboundElementsContainer
		select {
		case elemsContainer := <-peer.queue.staged:
			// Compact in place: elements with a usable nonce are kept at the
			// front of elemsContainer.elems (indices [0, i)).
			i := 0
			for _, elem := range elemsContainer.elems {
				elem.peer = peer
				elem.nonce = keypair.sendNonce.Add(1) - 1
				if elem.nonce >= RejectAfterMessages {
					// Pin the counter at the limit so the keypair is seen as
					// exhausted by the check at "top".
					keypair.sendNonce.Store(RejectAfterMessages)
					if elemsContainerOOO == nil {
						elemsContainerOOO = peer.device.GetOutboundElementsContainer()
					}
					elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem)
					continue
				} else {
					elemsContainer.elems[i] = elem
					i++
				}

				elem.keypair = keypair
			}
			// Lock before queueing: RoutineEncryption unlocks after sealing,
			// and RoutineSequentialSender locks (waits) before sending.
			elemsContainer.Lock()
			elemsContainer.elems = elemsContainer.elems[:i]

			if elemsContainerOOO != nil {
				peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans
			}

			if len(elemsContainer.elems) == 0 {
				peer.device.PutOutboundElementsContainer(elemsContainer)
				goto top
			}

			// add to parallel and sequential queue
			if peer.isRunning.Load() {
				peer.queue.outbound.c <- elemsContainer
				peer.device.queue.encryption.c <- elemsContainer
			} else {
				for _, elem := range elemsContainer.elems {
					peer.device.PutMessageBuffer(elem.buffer)
					peer.device.PutOutboundElement(elem)
				}
				peer.device.PutOutboundElementsContainer(elemsContainer)
			}

			// The keypair ran out mid-batch: restart at "top" so the exhausted
			// keypair triggers a handshake for the re-staged elements.
			if elemsContainerOOO != nil {
				goto top
			}
		default:
			return
		}
	}
}
   447  
   448  func (peer *Peer) FlushStagedPackets() {
   449  	for {
   450  		select {
   451  		case elemsContainer := <-peer.queue.staged:
   452  			for _, elem := range elemsContainer.elems {
   453  				peer.device.PutMessageBuffer(elem.buffer)
   454  				peer.device.PutOutboundElement(elem)
   455  			}
   456  			peer.device.PutOutboundElementsContainer(elemsContainer)
   457  		default:
   458  			return
   459  		}
   460  	}
   461  }
   462  
   463  func calculatePaddingSize(packetSize, mtu int) int {
   464  	lastUnit := packetSize
   465  	if mtu == 0 {
   466  		return ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)) - lastUnit
   467  	}
   468  	if lastUnit > mtu {
   469  		lastUnit %= mtu
   470  	}
   471  	paddedSize := ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1))
   472  	if paddedSize > mtu {
   473  		paddedSize = mtu
   474  	}
   475  	return paddedSize - lastUnit
   476  }
   477  
// RoutineEncryption is an encryption worker. For every container received on
// device.queue.encryption.c it:
//   - writes the transport header into the headroom in front of each packet,
//   - zero-pads the plaintext (calculatePaddingSize),
//   - seals it with the element's keypair and nonce,
//   - finally unlocks the container, marking it ready for sequential
//     consumption by RoutineSequentialSender.
//
// Per the original note, one instance runs per core; id is used only in log
// messages. The routine returns when the encryption channel is closed.
func (device *Device) RoutineEncryption(id int) {
	var paddingZeros [PaddingMultiple]byte
	// Bytes 0..3 of the AEAD nonce are never written and stay zero; the
	// 64-bit counter is stored little-endian in bytes 4..11 below.
	var nonce [chacha20poly1305.NonceSize]byte

	defer device.log.Verbosef("Routine: encryption worker %d - stopped", id)
	device.log.Verbosef("Routine: encryption worker %d - started", id)

	for elemsContainer := range device.queue.encryption.c {
		for _, elem := range elemsContainer.elems {
			// populate header fields
			// Layout: [0:4) message type, [4:8) receiver index, [8:16) counter.
			header := elem.buffer[:MessageTransportHeaderSize]

			fieldType := header[0:4]
			fieldReceiver := header[4:8]
			fieldNonce := header[8:16]

			binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
			binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
			binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)

			// pad content to multiple of 16
			// (more precisely PaddingMultiple, capped by the MTU; elem.packet
			// is a slice of elem.buffer, so append writes into the same buffer.)
			paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load()))
			elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)

			// encrypt content and release to consumer
			// Seal appends ciphertext+tag to header, which is a prefix of
			// elem.buffer. This assumes the plaintext sits immediately after
			// the header (as RoutineReadFromTUN arranges via its read offset),
			// making the encryption in place. Afterwards elem.packet spans
			// header + ciphertext, i.e. the full wire message.
			binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
			elem.packet = elem.keypair.send.Seal(
				header,
				nonce[:],
				elem.packet,
				nil,
			)
		}
		// Unlocked only after the whole batch is sealed, so the sequential
		// sender (which locks the container) never sends plaintext.
		elemsContainer.Unlock()
	}
}
   520  
   521  func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
   522  	device := peer.device
   523  	defer func() {
   524  		defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer)
   525  		peer.stopping.Done()
   526  	}()
   527  	device.log.Verbosef("%v - Routine: sequential sender - started", peer)
   528  
   529  	bufs := make([][]byte, 0, maxBatchSize)
   530  
   531  	for elemsContainer := range peer.queue.outbound.c {
   532  		bufs = bufs[:0]
   533  		if elemsContainer == nil {
   534  			return
   535  		}
   536  		if !peer.isRunning.Load() {
   537  			// peer has been stopped; return re-usable elems to the shared pool.
   538  			// This is an optimization only. It is possible for the peer to be stopped
   539  			// immediately after this check, in which case, elem will get processed.
   540  			// The timers and SendBuffers code are resilient to a few stragglers.
   541  			// TODO: rework peer shutdown order to ensure
   542  			// that we never accidentally keep timers alive longer than necessary.
   543  			elemsContainer.Lock()
   544  			for _, elem := range elemsContainer.elems {
   545  				device.PutMessageBuffer(elem.buffer)
   546  				device.PutOutboundElement(elem)
   547  			}
   548  			continue
   549  		}
   550  		dataSent := false
   551  		elemsContainer.Lock()
   552  		for _, elem := range elemsContainer.elems {
   553  			if len(elem.packet) != MessageKeepaliveSize {
   554  				dataSent = true
   555  			}
   556  			bufs = append(bufs, elem.packet)
   557  		}
   558  
   559  		peer.timersAnyAuthenticatedPacketTraversal()
   560  		peer.timersAnyAuthenticatedPacketSent()
   561  
   562  		err := peer.SendBuffers(bufs)
   563  		if dataSent {
   564  			peer.timersDataSent()
   565  		}
   566  		for _, elem := range elemsContainer.elems {
   567  			device.PutMessageBuffer(elem.buffer)
   568  			device.PutOutboundElement(elem)
   569  		}
   570  		device.PutOutboundElementsContainer(elemsContainer)
   571  		if err != nil {
   572  			var errGSO conn.ErrUDPGSODisabled
   573  			if errors.As(err, &errGSO) {
   574  				device.log.Verbosef(err.Error())
   575  				err = errGSO.RetryErr
   576  			}
   577  		}
   578  		if err != nil {
   579  			device.log.Errorf("%v - Failed to send data packets: %v", peer, err)
   580  			continue
   581  		}
   582  
   583  		peer.keepKeyFreshSending()
   584  	}
   585  }