github.com/tailscale/wireguard-go@v0.0.20201119-0.20210522003738-46b531feb08a/device/send.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"bytes"
    10  	"encoding/binary"
    11  	"errors"
    12  	"net"
    13  	"os"
    14  	"sync"
    15  	"sync/atomic"
    16  	"time"
    17  
    18  	"golang.org/x/crypto/chacha20poly1305"
    19  	"golang.org/x/net/ipv4"
    20  	"golang.org/x/net/ipv6"
    21  )
    22  
    23  /* Outbound flow
    24   *
    25   * 1. TUN queue
    26   * 2. Routing (sequential)
    27   * 3. Nonce assignment (sequential)
    28   * 4. Encryption (parallel)
    29   * 5. Transmission (sequential)
    30   *
    31   * The functions in this file occur (roughly) in the order in
    32   * which the packets are processed.
    33   *
    34   * Locking, Producers and Consumers
    35   *
 * The order of packets (per peer) must be maintained,
 * but encryption of packets happens out-of-order:
 *
 * The sequential consumers will attempt to take the lock;
 * workers release the lock when they have completed work (encryption) on the packet.
    41   *
    42   * If the element is inserted into the "encryption queue",
    43   * the content is preceded by enough "junk" to contain the transport header
    44   * (to allow the construction of transport messages in-place)
    45   */
    46  
// QueueOutboundElement carries one outbound packet through the staged,
// encryption, and per-peer outbound queues.
//
// The embedded Mutex orders encryption against transmission: SendStagedPackets
// locks the element before queueing it, RoutineEncryption unlocks it once the
// packet is sealed, and RoutineSequentialSender blocks on Lock until then.
type QueueOutboundElement struct {
	sync.Mutex
	buffer  *[MaxMessageSize]byte // pooled backing array holding the packet data
	packet  []byte                // slice of "buffer" (always!); keepalives never set it before encryption
	nonce   uint64                // nonce for encryption
	keypair *Keypair              // keypair for encryption
	peer    *Peer                 // related peer
}
    55  
    56  func (device *Device) NewOutboundElement() *QueueOutboundElement {
    57  	elem := device.GetOutboundElement()
    58  	elem.buffer = device.GetMessageBuffer()
    59  	elem.Mutex = sync.Mutex{}
    60  	elem.nonce = 0
    61  	// keypair and peer were cleared (if necessary) by clearPointers.
    62  	return elem
    63  }
    64  
    65  // clearPointers clears elem fields that contain pointers.
    66  // This makes the garbage collector's life easier and
    67  // avoids accidentally keeping other objects around unnecessarily.
    68  // It also reduces the possible collateral damage from use-after-free bugs.
    69  func (elem *QueueOutboundElement) clearPointers() {
    70  	elem.buffer = nil
    71  	elem.packet = nil
    72  	elem.keypair = nil
    73  	elem.peer = nil
    74  }
    75  
    76  /* Queues a keepalive if no packets are queued for peer
    77   */
    78  func (peer *Peer) SendKeepalive() {
    79  	if len(peer.queue.staged) == 0 && peer.isRunning.Get() {
    80  		elem := peer.device.NewOutboundElement()
    81  		select {
    82  		case peer.queue.staged <- elem:
    83  			peer.device.log.Verbosef("%v - Sending keepalive packet", peer)
    84  		default:
    85  			peer.device.PutMessageBuffer(elem.buffer)
    86  			peer.device.PutOutboundElement(elem)
    87  		}
    88  	}
    89  	peer.SendStagedPackets()
    90  }
    91  
    92  func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
    93  	if !isRetry {
    94  		atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
    95  	}
    96  
    97  	peer.handshake.mutex.RLock()
    98  	if time.Since(peer.handshake.lastSentHandshake) < RekeyTimeout {
    99  		peer.handshake.mutex.RUnlock()
   100  		return nil
   101  	}
   102  	peer.handshake.mutex.RUnlock()
   103  
   104  	peer.handshake.mutex.Lock()
   105  	if time.Since(peer.handshake.lastSentHandshake) < RekeyTimeout {
   106  		peer.handshake.mutex.Unlock()
   107  		return nil
   108  	}
   109  	peer.handshake.lastSentHandshake = time.Now()
   110  	peer.handshake.mutex.Unlock()
   111  
   112  	peer.device.log.Verbosef("%v - Sending handshake initiation", peer)
   113  
   114  	msg, err := peer.device.CreateMessageInitiation(peer)
   115  	if err != nil {
   116  		peer.device.log.Errorf("%v - Failed to create initiation message: %v", peer, err)
   117  		return err
   118  	}
   119  
   120  	var buff [MessageInitiationSize]byte
   121  	writer := bytes.NewBuffer(buff[:0])
   122  	binary.Write(writer, binary.LittleEndian, msg)
   123  	packet := writer.Bytes()
   124  	peer.cookieGenerator.AddMacs(packet)
   125  
   126  	peer.timersAnyAuthenticatedPacketTraversal()
   127  	peer.timersAnyAuthenticatedPacketSent()
   128  
   129  	err = peer.SendBuffer(packet)
   130  	if err != nil {
   131  		peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
   132  	}
   133  	peer.timersHandshakeInitiated()
   134  
   135  	return err
   136  }
   137  
   138  func (peer *Peer) SendHandshakeResponse() error {
   139  	peer.handshake.mutex.Lock()
   140  	peer.handshake.lastSentHandshake = time.Now()
   141  	peer.handshake.mutex.Unlock()
   142  
   143  	peer.device.log.Verbosef("%v - Sending handshake response", peer)
   144  
   145  	response, err := peer.device.CreateMessageResponse(peer)
   146  	if err != nil {
   147  		peer.device.log.Errorf("%v - Failed to create response message: %v", peer, err)
   148  		return err
   149  	}
   150  
   151  	var buff [MessageResponseSize]byte
   152  	writer := bytes.NewBuffer(buff[:0])
   153  	binary.Write(writer, binary.LittleEndian, response)
   154  	packet := writer.Bytes()
   155  	peer.cookieGenerator.AddMacs(packet)
   156  
   157  	err = peer.BeginSymmetricSession()
   158  	if err != nil {
   159  		peer.device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
   160  		return err
   161  	}
   162  
   163  	peer.timersSessionDerived()
   164  	peer.timersAnyAuthenticatedPacketTraversal()
   165  	peer.timersAnyAuthenticatedPacketSent()
   166  
   167  	err = peer.SendBuffer(packet)
   168  	if err != nil {
   169  		peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
   170  	}
   171  	return err
   172  }
   173  
   174  func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) error {
   175  	device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString())
   176  
   177  	sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8])
   178  	reply, err := device.cookieChecker.CreateReply(initiatingElem.packet, sender, initiatingElem.endpoint.DstToBytes())
   179  	if err != nil {
   180  		device.log.Errorf("Failed to create cookie reply: %v", err)
   181  		return err
   182  	}
   183  
   184  	var buff [MessageCookieReplySize]byte
   185  	writer := bytes.NewBuffer(buff[:0])
   186  	binary.Write(writer, binary.LittleEndian, reply)
   187  	device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint)
   188  	return nil
   189  }
   190  
   191  func (peer *Peer) keepKeyFreshSending() {
   192  	keypair := peer.keypairs.Current()
   193  	if keypair == nil {
   194  		return
   195  	}
   196  	nonce := atomic.LoadUint64(&keypair.sendNonce)
   197  	if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) {
   198  		peer.SendHandshakeInitiation(false)
   199  	}
   200  }
   201  
// RoutineReadFromTUN reads packets from the TUN device and stages each one on
// the peer found by looking up the packet's destination address in
// device.allowedips.
//
// Obs. Single instance per TUN device.
//
// Each packet is read at offset MessageTransportHeaderSize into a pooled
// buffer, leaving room in front for the transport header. Some packets are
// dropped and their element recycled:
//   - empty packets, or packets larger than MaxContentSize;
//   - packets too short for their IPv4/IPv6 header;
//   - packets with an unknown IP version;
//   - packets with no matching or running peer.
//
// A read error ends the routine. Unless the device is already closed, the
// error is logged (except os.ErrClosed) and device.Close is started in a new
// goroutine.
func (device *Device) RoutineReadFromTUN() {
	defer func() {
		device.log.Verbosef("Routine: TUN reader - stopped")
		device.state.stopping.Done()
		device.queue.encryption.wg.Done()
	}()

	device.log.Verbosef("Routine: TUN reader - started")

	// elem survives to the next iteration only if it was NOT handed off to a
	// peer; in that case it is recycled at the top of the loop.
	var elem *QueueOutboundElement

	for {
		if elem != nil {
			device.PutMessageBuffer(elem.buffer)
			device.PutOutboundElement(elem)
		}
		elem = device.NewOutboundElement()

		// read packet (payload lands after space reserved for the transport header)

		offset := MessageTransportHeaderSize
		size, err := device.tun.device.Read(elem.buffer[:], offset)

		if err != nil {
			if !device.isClosed() {
				if !errors.Is(err, os.ErrClosed) {
					device.log.Errorf("Failed to read packet from TUN device: %v", err)
				}
				// NOTE(review): presumably asynchronous because Close waits on
				// state.stopping, which this routine only releases in its defer.
				go device.Close()
			}
			device.PutMessageBuffer(elem.buffer)
			device.PutOutboundElement(elem)
			return
		}

		if size == 0 || size > MaxContentSize {
			continue
		}

		elem.packet = elem.buffer[offset : offset+size]

		// lookup peer; the high nibble of the first byte is the IP version

		var peer *Peer
		switch elem.packet[0] >> 4 {
		case ipv4.Version:
			if len(elem.packet) < ipv4.HeaderLen {
				continue
			}
			dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
			peer = device.allowedips.LookupIPv4(dst)

		case ipv6.Version:
			if len(elem.packet) < ipv6.HeaderLen {
				continue
			}
			dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
			peer = device.allowedips.LookupIPv6(dst)

		default:
			device.log.Verbosef("Received packet with unknown IP version")
		}

		if peer == nil {
			continue
		}
		if peer.isRunning.Get() {
			// Ownership of elem passes to the peer's staged queue; clear the
			// local so it is not recycled at the top of the loop.
			peer.StagePacket(elem)
			elem = nil
			peer.SendStagedPackets()
		}
	}
}
   280  
   281  func (peer *Peer) StagePacket(elem *QueueOutboundElement) {
   282  	for {
   283  		select {
   284  		case peer.queue.staged <- elem:
   285  			return
   286  		default:
   287  		}
   288  		select {
   289  		case tooOld := <-peer.queue.staged:
   290  			peer.device.PutMessageBuffer(tooOld.buffer)
   291  			peer.device.PutOutboundElement(tooOld)
   292  		default:
   293  		}
   294  	}
   295  }
   296  
// SendStagedPackets moves every staged packet of peer into the per-peer
// outbound queue and the device-wide encryption queue. Each packet is tagged
// with the current keypair and a freshly reserved nonce.
//
// It does nothing if no packets are staged or the device is not up. If there
// is no usable current keypair (none, nonce at RejectAfterMessages, or older
// than RejectAfterTime), a handshake initiation is sent instead and the
// packets remain staged.
func (peer *Peer) SendStagedPackets() {
top:
	if len(peer.queue.staged) == 0 || !peer.device.isUp() {
		return
	}

	keypair := peer.keypairs.Current()
	if keypair == nil || atomic.LoadUint64(&keypair.sendNonce) >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime {
		peer.SendHandshakeInitiation(false)
		return
	}

	for {
		select {
		case elem := <-peer.queue.staged:
			elem.peer = peer
			// Atomically reserve a unique nonce (AddUint64 returns the new value).
			elem.nonce = atomic.AddUint64(&keypair.sendNonce, 1) - 1
			if elem.nonce >= RejectAfterMessages {
				// Keypair exhausted mid-batch: pin the counter, re-stage the packet
				// and start over so the keypair check above is re-evaluated.
				atomic.StoreUint64(&keypair.sendNonce, RejectAfterMessages)
				peer.StagePacket(elem) // XXX: Out of order, but we can't front-load go chans
				goto top
			}

			elem.keypair = keypair
			// Locked here and unlocked by the encryption worker once sealed;
			// the sequential sender waits on this lock to keep packet order.
			elem.Lock()

			// add to parallel and sequential queue
			if peer.isRunning.Get() {
				peer.queue.outbound.c <- elem
				peer.device.queue.encryption.c <- elem
			} else {
				peer.device.PutMessageBuffer(elem.buffer)
				peer.device.PutOutboundElement(elem)
			}
		default:
			return
		}
	}
}
   336  
   337  func (peer *Peer) FlushStagedPackets() {
   338  	for {
   339  		select {
   340  		case elem := <-peer.queue.staged:
   341  			peer.device.PutMessageBuffer(elem.buffer)
   342  			peer.device.PutOutboundElement(elem)
   343  		default:
   344  			return
   345  		}
   346  	}
   347  }
   348  
   349  func calculatePaddingSize(packetSize, mtu int) int {
   350  	lastUnit := packetSize
   351  	if mtu == 0 {
   352  		return ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)) - lastUnit
   353  	}
   354  	if lastUnit > mtu {
   355  		lastUnit %= mtu
   356  	}
   357  	paddedSize := ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1))
   358  	if paddedSize > mtu {
   359  		paddedSize = mtu
   360  	}
   361  	return paddedSize - lastUnit
   362  }
   363  
   364  /* Encrypts the elements in the queue
   365   * and marks them for sequential consumption (by releasing the mutex)
   366   *
   367   * Obs. One instance per core
   368   */
   369  func (device *Device) RoutineEncryption(id int) {
   370  	var paddingZeros [PaddingMultiple]byte
   371  	var nonce [chacha20poly1305.NonceSize]byte
   372  
   373  	defer device.log.Verbosef("Routine: encryption worker %d - stopped", id)
   374  	device.log.Verbosef("Routine: encryption worker %d - started", id)
   375  
   376  	for elem := range device.queue.encryption.c {
   377  		// populate header fields
   378  		header := elem.buffer[:MessageTransportHeaderSize]
   379  
   380  		fieldType := header[0:4]
   381  		fieldReceiver := header[4:8]
   382  		fieldNonce := header[8:16]
   383  
   384  		binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
   385  		binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
   386  		binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
   387  
   388  		// pad content to multiple of 16
   389  		paddingSize := calculatePaddingSize(len(elem.packet), int(atomic.LoadInt32(&device.tun.mtu)))
   390  		elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)
   391  
   392  		// encrypt content and release to consumer
   393  
   394  		binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
   395  		elem.packet = elem.keypair.send.Seal(
   396  			header,
   397  			nonce[:],
   398  			elem.packet,
   399  			nil,
   400  		)
   401  		elem.Unlock()
   402  	}
   403  }
   404  
// RoutineSequentialSender transmits a peer's encrypted packets in the order
// they were queued, sending each one to the peer's endpoint.
//
// Obs. Single instance per peer.
// The routine terminates when the outbound queue is closed, or when a nil
// element is received from it.
func (peer *Peer) RoutineSequentialSender() {
	device := peer.device
	defer func() {
		// Inner defer: the "stopped" message is logged after stopping.Done().
		defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer)
		peer.stopping.Done()
	}()
	device.log.Verbosef("%v - Routine: sequential sender - started", peer)

	for elem := range peer.queue.outbound.c {
		if elem == nil {
			return
		}
		// Blocks until the encryption worker has sealed elem and unlocked it.
		// The element is recycled below without being unlocked again.
		elem.Lock()
		if !peer.isRunning.Get() {
			// peer has been stopped; return re-usable elems to the shared pool.
			// This is an optimization only. It is possible for the peer to be stopped
			// immediately after this check, in which case, elem will get processed.
			// The timers and SendBuffer code are resilient to a few stragglers.
			// TODO: rework peer shutdown order to ensure
			// that we never accidentally keep timers alive longer than necessary.
			device.PutMessageBuffer(elem.buffer)
			device.PutOutboundElement(elem)
			continue
		}

		peer.timersAnyAuthenticatedPacketTraversal()
		peer.timersAnyAuthenticatedPacketSent()

		// send message and return buffer to pool

		err := peer.SendBuffer(elem.packet)
		// Packets of exactly MessageKeepaliveSize (presumably keepalives) do
		// not count as data for the timers.
		if len(elem.packet) != MessageKeepaliveSize {
			peer.timersDataSent()
		}
		device.PutMessageBuffer(elem.buffer)
		device.PutOutboundElement(elem)
		if err != nil {
			device.log.Errorf("%v - Failed to send data packet: %v", peer, err)
			continue
		}

		peer.keepKeyFreshSending()
	}
}