gitee.com/aurawing/surguard-go@v0.3.1-0.20240409071558-96509a61ecf3/device/send.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"bytes"
    10  	"encoding/binary"
    11  	"errors"
    12  	"net"
    13  	"os"
    14  	"sync"
    15  	"time"
    16  
    17  	"gitee.com/aurawing/surguard-go/conn"
    18  	"gitee.com/aurawing/surguard-go/tun"
    19  	"golang.org/x/crypto/chacha20poly1305"
    20  	"golang.org/x/net/ipv4"
    21  	"golang.org/x/net/ipv6"
    22  )
    23  
    24  /* Outbound flow
    25   *
    26   * 1. TUN queue
    27   * 2. Routing (sequential)
    28   * 3. Nonce assignment (sequential)
    29   * 4. Encryption (parallel)
    30   * 5. Transmission (sequential)
    31   *
    32   * The functions in this file occur (roughly) in the order in
    33   * which the packets are processed.
    34   *
    35   * Locking, Producers and Consumers
    36   *
    37   * The order of packets (per peer) must be maintained,
    38   * but encryption of packets happens out-of-order:
    39   *
    40   * The sequential consumers will attempt to take the lock,
    41   * workers release lock when they have completed work (encryption) on the packet.
    42   *
    43   * If the element is inserted into the "encryption queue",
    44   * the content is preceded by enough "junk" to contain the transport header
    45   * (to allow the construction of transport messages in-place)
    46   */
    47  
    48  type QueueOutboundElement struct {
    49  	buffer  *[MaxMessageSize]byte // slice holding the packet data
    50  	packet  []byte                // slice of "buffer" (always!)
    51  	nonce   uint64                // nonce for encryption
    52  	keypair *Keypair              // keypair for encryption
    53  	peer    *Peer                 // related peer
    54  }
    55  
    56  type QueueOutboundElementsContainer struct {
    57  	sync.Mutex
    58  	elems []*QueueOutboundElement
    59  }
    60  
    61  func (device *Device) NewOutboundElement() *QueueOutboundElement {
    62  	elem := device.GetOutboundElement()
    63  	elem.buffer = device.GetMessageBuffer()
    64  	elem.nonce = 0
    65  	// keypair and peer were cleared (if necessary) by clearPointers.
    66  	return elem
    67  }
    68  
    69  // clearPointers clears elem fields that contain pointers.
    70  // This makes the garbage collector's life easier and
    71  // avoids accidentally keeping other objects around unnecessarily.
    72  // It also reduces the possible collateral damage from use-after-free bugs.
    73  func (elem *QueueOutboundElement) clearPointers() {
    74  	elem.buffer = nil
    75  	elem.packet = nil
    76  	elem.keypair = nil
    77  	elem.peer = nil
    78  }
    79  
    80  /* Queues a keepalive if no packets are queued for peer
    81   */
    82  func (peer *Peer) SendKeepalive() {
    83  	if len(peer.queue.staged) == 0 && peer.isRunning.Load() {
    84  		elem := peer.device.NewOutboundElement()
    85  		elemsContainer := peer.device.GetOutboundElementsContainer()
    86  		elemsContainer.elems = append(elemsContainer.elems, elem)
    87  		select {
    88  		case peer.queue.staged <- elemsContainer:
    89  			peer.device.log.Verbosef("%v - Sending keepalive packet", peer)
    90  		default:
    91  			peer.device.PutMessageBuffer(elem.buffer)
    92  			peer.device.PutOutboundElement(elem)
    93  			peer.device.PutOutboundElementsContainer(elemsContainer)
    94  		}
    95  	}
    96  	peer.SendStagedPackets()
    97  }
    98  
    99  func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
   100  	if !isRetry {
   101  		peer.timers.handshakeAttempts.Store(0)
   102  	}
   103  
   104  	peer.handshake.mutex.RLock()
   105  	if time.Since(peer.handshake.lastSentHandshake) < RekeyTimeout {
   106  		peer.handshake.mutex.RUnlock()
   107  		return nil
   108  	}
   109  	peer.handshake.mutex.RUnlock()
   110  
   111  	peer.handshake.mutex.Lock()
   112  	if time.Since(peer.handshake.lastSentHandshake) < RekeyTimeout {
   113  		peer.handshake.mutex.Unlock()
   114  		return nil
   115  	}
   116  	peer.handshake.lastSentHandshake = time.Now()
   117  	peer.handshake.mutex.Unlock()
   118  
   119  	peer.device.log.Verbosef("%v - Sending handshake initiation", peer)
   120  
   121  	msg, err := peer.device.CreateMessageInitiation(peer)
   122  	if err != nil {
   123  		peer.device.log.Errorf("%v - Failed to create initiation message: %v", peer, err)
   124  		return err
   125  	}
   126  
   127  	var buf [MessageInitiationSize]byte
   128  	writer := bytes.NewBuffer(buf[:0])
   129  	binary.Write(writer, binary.LittleEndian, msg)
   130  	packet := writer.Bytes()
   131  	peer.cookieGenerator.AddMacs(packet)
   132  
   133  	peer.timersAnyAuthenticatedPacketTraversal()
   134  	peer.timersAnyAuthenticatedPacketSent()
   135  
   136  	err = peer.SendBuffers([][]byte{packet})
   137  	if err != nil {
   138  		peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
   139  	}
   140  	peer.timersHandshakeInitiated()
   141  
   142  	return err
   143  }
   144  
   145  func (peer *Peer) SendHandshakeResponse() error {
   146  	peer.handshake.mutex.Lock()
   147  	peer.handshake.lastSentHandshake = time.Now()
   148  	peer.handshake.mutex.Unlock()
   149  
   150  	peer.device.log.Verbosef("%v - Sending handshake response", peer)
   151  
   152  	response, err := peer.device.CreateMessageResponse(peer)
   153  	if err != nil {
   154  		peer.device.log.Errorf("%v - Failed to create response message: %v", peer, err)
   155  		return err
   156  	}
   157  
   158  	var buf [MessageResponseSize]byte
   159  	writer := bytes.NewBuffer(buf[:0])
   160  	binary.Write(writer, binary.LittleEndian, response)
   161  	packet := writer.Bytes()
   162  	peer.cookieGenerator.AddMacs(packet)
   163  
   164  	err = peer.BeginSymmetricSession()
   165  	if err != nil {
   166  		peer.device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
   167  		return err
   168  	}
   169  
   170  	peer.timersSessionDerived()
   171  	peer.timersAnyAuthenticatedPacketTraversal()
   172  	peer.timersAnyAuthenticatedPacketSent()
   173  
   174  	// TODO: allocation could be avoided
   175  	err = peer.SendBuffers([][]byte{packet})
   176  	if err != nil {
   177  		peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
   178  	}
   179  	return err
   180  }
   181  
   182  func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) error {
   183  	device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString())
   184  
   185  	sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8])
   186  	reply, err := device.cookieChecker.CreateReply(initiatingElem.packet, sender, initiatingElem.endpoint.DstToBytes())
   187  	if err != nil {
   188  		device.log.Errorf("Failed to create cookie reply: %v", err)
   189  		return err
   190  	}
   191  
   192  	var buf [MessageCookieReplySize]byte
   193  	writer := bytes.NewBuffer(buf[:0])
   194  	binary.Write(writer, binary.LittleEndian, reply)
   195  	// TODO: allocation could be avoided
   196  	device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint)
   197  	return nil
   198  }
   199  
   200  func (peer *Peer) keepKeyFreshSending() {
   201  	keypair := peer.keypairs.Current()
   202  	if keypair == nil {
   203  		return
   204  	}
   205  	nonce := keypair.sendNonce.Load()
   206  	if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) {
   207  		peer.SendHandshakeInitiation(false)
   208  	}
   209  }
   210  
   211  func (device *Device) RoutineReadFromTUN() {
   212  	defer func() {
   213  		device.log.Verbosef("Routine: TUN reader - stopped")
   214  		device.state.stopping.Done()
   215  		device.queue.encryption.wg.Done()
   216  	}()
   217  
   218  	device.log.Verbosef("Routine: TUN reader - started")
   219  
   220  	var (
   221  		batchSize   = device.BatchSize()
   222  		readErr     error
   223  		elems       = make([]*QueueOutboundElement, batchSize)
   224  		bufs        = make([][]byte, batchSize)
   225  		elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize)
   226  		count       = 0
   227  		sizes       = make([]int, batchSize)
   228  		offset      = MessageTransportHeaderSize
   229  	)
   230  
   231  	for i := range elems {
   232  		elems[i] = device.NewOutboundElement()
   233  		bufs[i] = elems[i].buffer[:]
   234  	}
   235  
   236  	defer func() {
   237  		for _, elem := range elems {
   238  			if elem != nil {
   239  				device.PutMessageBuffer(elem.buffer)
   240  				device.PutOutboundElement(elem)
   241  			}
   242  		}
   243  	}()
   244  
   245  	for {
   246  		// read packets
   247  		count, readErr = device.tun.device.Read(bufs, sizes, offset)
   248  		for i := 0; i < count; i++ {
   249  			if sizes[i] < 1 {
   250  				continue
   251  			}
   252  
   253  			elem := elems[i]
   254  			elem.packet = bufs[i][offset : offset+sizes[i]]
   255  
   256  			// lookup peer
   257  			var peer *Peer
   258  			switch elem.packet[0] >> 4 {
   259  			case 4:
   260  				if len(elem.packet) < ipv4.HeaderLen {
   261  					continue
   262  				}
   263  				src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
   264  				if ifBindInterface && zkCli != nil && interfaceIP != "" && src[0] == 169 && src[1] == 254 && src[2] == interfaceIndex {
   265  					zkCli.RLock()
   266  					src[0] = interfaceIPArr[src[3]-1][0]
   267  					src[1] = interfaceIPArr[src[3]-1][1]
   268  					src[2] = interfaceIPArr[src[3]-1][2]
   269  					src[3] = interfaceIPArr[src[3]-1][3]
   270  					elem.packet[10] = 0
   271  					elem.packet[11] = 0
   272  					checksum := IPv4CheckSum(elem.packet[0:20])
   273  					binary.BigEndian.PutUint16(elem.packet[10:12], checksum)
   274  					zkCli.RUnlock()
   275  					//TODO: change TCP & UDP checksum
   276  					protoType := elem.packet[IPv4offsetProtoType]
   277  					switch protoType {
   278  					case 6:
   279  						TCPv4CheckSum(elem.packet)
   280  					case 17:
   281  						UDPv4CheckSum(elem.packet)
   282  					}
   283  				}
   284  				dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
   285  				peer = device.allowedips.Lookup(dst)
   286  
   287  			case 6:
   288  				if len(elem.packet) < ipv6.HeaderLen {
   289  					continue
   290  				}
   291  				dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
   292  				peer = device.allowedips.Lookup(dst)
   293  
   294  			default:
   295  				device.log.Verbosef("Received packet with unknown IP version")
   296  			}
   297  
   298  			if peer == nil {
   299  				continue
   300  			}
   301  			elemsForPeer, ok := elemsByPeer[peer]
   302  			if !ok {
   303  				elemsForPeer = device.GetOutboundElementsContainer()
   304  				elemsByPeer[peer] = elemsForPeer
   305  			}
   306  			elemsForPeer.elems = append(elemsForPeer.elems, elem)
   307  			elems[i] = device.NewOutboundElement()
   308  			bufs[i] = elems[i].buffer[:]
   309  		}
   310  
   311  		for peer, elemsForPeer := range elemsByPeer {
   312  			if peer.isRunning.Load() {
   313  				peer.StagePackets(elemsForPeer)
   314  				peer.SendStagedPackets()
   315  			} else {
   316  				for _, elem := range elemsForPeer.elems {
   317  					device.PutMessageBuffer(elem.buffer)
   318  					device.PutOutboundElement(elem)
   319  				}
   320  				device.PutOutboundElementsContainer(elemsForPeer)
   321  			}
   322  			delete(elemsByPeer, peer)
   323  		}
   324  
   325  		if readErr != nil {
   326  			if errors.Is(readErr, tun.ErrTooManySegments) {
   327  				// TODO: record stat for this
   328  				// This will happen if MSS is surprisingly small (< 576)
   329  				// coincident with reasonably high throughput.
   330  				device.log.Verbosef("Dropped some packets from multi-segment read: %v", readErr)
   331  				continue
   332  			}
   333  			if !device.isClosed() {
   334  				if !errors.Is(readErr, os.ErrClosed) {
   335  					device.log.Errorf("Failed to read packet from TUN device: %v", readErr)
   336  				}
   337  				go device.Close()
   338  			}
   339  			return
   340  		}
   341  	}
   342  }
   343  
   344  func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) {
   345  	for {
   346  		select {
   347  		case peer.queue.staged <- elems:
   348  			return
   349  		default:
   350  		}
   351  		select {
   352  		case tooOld := <-peer.queue.staged:
   353  			for _, elem := range tooOld.elems {
   354  				peer.device.PutMessageBuffer(elem.buffer)
   355  				peer.device.PutOutboundElement(elem)
   356  			}
   357  			peer.device.PutOutboundElementsContainer(tooOld)
   358  		default:
   359  		}
   360  	}
   361  }
   362  
   363  func (peer *Peer) SendStagedPackets() {
   364  top:
   365  	if len(peer.queue.staged) == 0 || !peer.device.isUp() {
   366  		return
   367  	}
   368  
   369  	keypair := peer.keypairs.Current()
   370  	if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime {
   371  		peer.SendHandshakeInitiation(false)
   372  		return
   373  	}
   374  
   375  	for {
   376  		var elemsContainerOOO *QueueOutboundElementsContainer
   377  		select {
   378  		case elemsContainer := <-peer.queue.staged:
   379  			i := 0
   380  			for _, elem := range elemsContainer.elems {
   381  				elem.peer = peer
   382  				elem.nonce = keypair.sendNonce.Add(1) - 1
   383  				if elem.nonce >= RejectAfterMessages {
   384  					keypair.sendNonce.Store(RejectAfterMessages)
   385  					if elemsContainerOOO == nil {
   386  						elemsContainerOOO = peer.device.GetOutboundElementsContainer()
   387  					}
   388  					elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem)
   389  					continue
   390  				} else {
   391  					elemsContainer.elems[i] = elem
   392  					i++
   393  				}
   394  
   395  				elem.keypair = keypair
   396  			}
   397  			elemsContainer.Lock()
   398  			elemsContainer.elems = elemsContainer.elems[:i]
   399  
   400  			if elemsContainerOOO != nil {
   401  				peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans
   402  			}
   403  
   404  			if len(elemsContainer.elems) == 0 {
   405  				peer.device.PutOutboundElementsContainer(elemsContainer)
   406  				goto top
   407  			}
   408  
   409  			// add to parallel and sequential queue
   410  			if peer.isRunning.Load() {
   411  				peer.queue.outbound.c <- elemsContainer
   412  				peer.device.queue.encryption.c <- elemsContainer
   413  			} else {
   414  				for _, elem := range elemsContainer.elems {
   415  					peer.device.PutMessageBuffer(elem.buffer)
   416  					peer.device.PutOutboundElement(elem)
   417  				}
   418  				peer.device.PutOutboundElementsContainer(elemsContainer)
   419  			}
   420  
   421  			if elemsContainerOOO != nil {
   422  				goto top
   423  			}
   424  		default:
   425  			return
   426  		}
   427  	}
   428  }
   429  
   430  func (peer *Peer) FlushStagedPackets() {
   431  	for {
   432  		select {
   433  		case elemsContainer := <-peer.queue.staged:
   434  			for _, elem := range elemsContainer.elems {
   435  				peer.device.PutMessageBuffer(elem.buffer)
   436  				peer.device.PutOutboundElement(elem)
   437  			}
   438  			peer.device.PutOutboundElementsContainer(elemsContainer)
   439  		default:
   440  			return
   441  		}
   442  	}
   443  }
   444  
   445  func calculatePaddingSize(packetSize, mtu int) int {
   446  	lastUnit := packetSize
   447  	if mtu == 0 {
   448  		return ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)) - lastUnit
   449  	}
   450  	if lastUnit > mtu {
   451  		lastUnit %= mtu
   452  	}
   453  	paddedSize := ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1))
   454  	if paddedSize > mtu {
   455  		paddedSize = mtu
   456  	}
   457  	return paddedSize - lastUnit
   458  }
   459  
   460  /* Encrypts the elements in the queue
   461   * and marks them for sequential consumption (by releasing the mutex)
   462   *
   463   * Obs. One instance per core
   464   */
   465  func (device *Device) RoutineEncryption(id int) {
   466  	var paddingZeros [PaddingMultiple]byte
   467  	var nonce [chacha20poly1305.NonceSize]byte
   468  
   469  	defer device.log.Verbosef("Routine: encryption worker %d - stopped", id)
   470  	device.log.Verbosef("Routine: encryption worker %d - started", id)
   471  
   472  	for elemsContainer := range device.queue.encryption.c {
   473  		for _, elem := range elemsContainer.elems {
   474  			// populate header fields
   475  			header := elem.buffer[:MessageTransportHeaderSize]
   476  
   477  			fieldType := header[0:4]
   478  			fieldReceiver := header[4:8]
   479  			fieldNonce := header[8:16]
   480  
   481  			binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
   482  			binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
   483  			binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
   484  
   485  			// pad content to multiple of 16
   486  			paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load()))
   487  			elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)
   488  
   489  			// encrypt content and release to consumer
   490  
   491  			binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
   492  			elem.packet = elem.keypair.send.Seal(
   493  				header,
   494  				nonce[:],
   495  				elem.packet,
   496  				nil,
   497  			)
   498  		}
   499  		elemsContainer.Unlock()
   500  	}
   501  }
   502  
   503  func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
   504  	device := peer.device
   505  	defer func() {
   506  		defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer)
   507  		peer.stopping.Done()
   508  	}()
   509  	device.log.Verbosef("%v - Routine: sequential sender - started", peer)
   510  
   511  	bufs := make([][]byte, 0, maxBatchSize)
   512  
   513  	for elemsContainer := range peer.queue.outbound.c {
   514  		bufs = bufs[:0]
   515  		if elemsContainer == nil {
   516  			return
   517  		}
   518  		if !peer.isRunning.Load() {
   519  			// peer has been stopped; return re-usable elems to the shared pool.
   520  			// This is an optimization only. It is possible for the peer to be stopped
   521  			// immediately after this check, in which case, elem will get processed.
   522  			// The timers and SendBuffers code are resilient to a few stragglers.
   523  			// TODO: rework peer shutdown order to ensure
   524  			// that we never accidentally keep timers alive longer than necessary.
   525  			elemsContainer.Lock()
   526  			for _, elem := range elemsContainer.elems {
   527  				device.PutMessageBuffer(elem.buffer)
   528  				device.PutOutboundElement(elem)
   529  			}
   530  			continue
   531  		}
   532  		dataSent := false
   533  		elemsContainer.Lock()
   534  		for _, elem := range elemsContainer.elems {
   535  			if len(elem.packet) != MessageKeepaliveSize {
   536  				dataSent = true
   537  			}
   538  			bufs = append(bufs, elem.packet)
   539  		}
   540  
   541  		peer.timersAnyAuthenticatedPacketTraversal()
   542  		peer.timersAnyAuthenticatedPacketSent()
   543  
   544  		err := peer.SendBuffers(bufs)
   545  		if dataSent {
   546  			peer.timersDataSent()
   547  		}
   548  		for _, elem := range elemsContainer.elems {
   549  			device.PutMessageBuffer(elem.buffer)
   550  			device.PutOutboundElement(elem)
   551  		}
   552  		device.PutOutboundElementsContainer(elemsContainer)
   553  		if err != nil {
   554  			var errGSO conn.ErrUDPGSODisabled
   555  			if errors.As(err, &errGSO) {
   556  				device.log.Verbosef(err.Error())
   557  				err = errGSO.RetryErr
   558  			}
   559  		}
   560  		if err != nil {
   561  			device.log.Errorf("%v - Failed to send data packets: %v", peer, err)
   562  			continue
   563  		}
   564  
   565  		peer.keepKeyFreshSending()
   566  	}
   567  }