github.com/amnezia-vpn/amneziawg-go@v0.2.8/device/receive.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"bytes"
    10  	"encoding/binary"
    11  	"errors"
    12  	"net"
    13  	"sync"
    14  	"time"
    15  
    16  	"github.com/amnezia-vpn/amneziawg-go/conn"
    17  	"golang.org/x/crypto/chacha20poly1305"
    18  	"golang.org/x/net/ipv4"
    19  	"golang.org/x/net/ipv6"
    20  )
    21  
// QueueHandshakeElement carries one handshake-related datagram (initiation,
// response or cookie reply) from RoutineReceiveIncoming to a RoutineHandshake
// worker through device.queue.handshake.c.
type QueueHandshakeElement struct {
	// message type read from the packet header
	msgType  uint32
	// the message bytes, with any advanced-security junk prefix already
	// stripped; this slice aliases *buffer
	packet   []byte
	// endpoint the datagram was received from
	endpoint conn.Endpoint
	// pooled backing storage; the handshake worker returns it with
	// PutMessageBuffer once the element has been handled
	buffer   *[MaxMessageSize]byte
}
    28  
// QueueInboundElement is one inbound transport message travelling from
// RoutineReceiveIncoming through RoutineDecryption to the peer's
// RoutineSequentialReceiver.
type QueueInboundElement struct {
	// pooled backing storage that packet points into
	buffer   *[MaxMessageSize]byte
	// encrypted transport message on receipt; after RoutineDecryption the
	// decrypted payload, or nil if decryption failed
	packet   []byte
	// message counter (nonce) filled in by RoutineDecryption and checked
	// against the keypair's replay filter by the sequential receiver
	counter  uint64
	// keypair found via the message's receiver index
	keypair  *Keypair
	// endpoint the datagram was received from
	endpoint conn.Endpoint
}
    36  
// QueueInboundElementsContainer is a batch of inbound elements for a single
// peer, collected from one receive call.
//
// The embedded mutex is used as a completion signal rather than for mutual
// exclusion in the usual sense: RoutineReceiveIncoming locks it when the
// container is created, RoutineDecryption unlocks it after decrypting every
// element, and RoutineSequentialReceiver calls Lock to wait for that.
type QueueInboundElementsContainer struct {
	sync.Mutex
	elems []*QueueInboundElement
}
    41  
    42  // clearPointers clears elem fields that contain pointers.
    43  // This makes the garbage collector's life easier and
    44  // avoids accidentally keeping other objects around unnecessarily.
    45  // It also reduces the possible collateral damage from use-after-free bugs.
    46  func (elem *QueueInboundElement) clearPointers() {
    47  	elem.buffer = nil
    48  	elem.packet = nil
    49  	elem.keypair = nil
    50  	elem.endpoint = nil
    51  }
    52  
    53  /* Called when a new authenticated message has been received
    54   *
    55   * NOTE: Not thread safe, but called by sequential receiver!
    56   */
    57  func (peer *Peer) keepKeyFreshReceiving() {
    58  	if peer.timers.sentLastMinuteHandshake.Load() {
    59  		return
    60  	}
    61  	keypair := peer.keypairs.Current()
    62  	if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
    63  		peer.timers.sentLastMinuteHandshake.Store(true)
    64  		peer.SendHandshakeInitiation(false)
    65  	}
    66  }
    67  
/* Receives incoming datagrams for the device
 *
 * Every time the bind is updated a new routine is started for
 * IPv4 and IPv6 (separately)
 *
 * Each recv call fills up to maxBatchSize buffers.
 *   - Transport messages are grouped per peer into a
 *     QueueInboundElementsContainer. Each container is sent to the peer's
 *     inbound queue and to the shared decryption queue.
 *   - Handshake-related messages are offered to the handshake queue without
 *     blocking, and are dropped if that queue is full.
 *
 * Buffer ownership: whenever a packet is handed off to a queue, its slot in
 * bufsArrs is refilled with a fresh pooled buffer. The deferred cleanup
 * returns whatever buffers the routine still owns when it exits.
 *
 * The routine stops in any of these cases:
 *   - recv reports net.ErrClosed;
 *   - recv returns a net.Error that is not Temporary;
 *   - recv keeps failing after 10 consecutive retries spaced 1/3 s apart.
 */
func (device *Device) RoutineReceiveIncoming(
	maxBatchSize int,
	recv conn.ReceiveFunc,
) {
	recvName := recv.PrettyName()
	defer func() {
		device.log.Verbosef("Routine: receive incoming %s - stopped", recvName)
		device.queue.decryption.wg.Done()
		device.queue.handshake.wg.Done()
		device.net.stopping.Done()
	}()

	device.log.Verbosef("Routine: receive incoming %s - started", recvName)

	// receive datagrams until conn is closed

	var (
		bufsArrs    = make([]*[MaxMessageSize]byte, maxBatchSize)
		bufs        = make([][]byte, maxBatchSize)
		err         error
		sizes       = make([]int, maxBatchSize)
		count       int
		endpoints   = make([]conn.Endpoint, maxBatchSize)
		deathSpiral int
		elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize)
	)

	for i := range bufsArrs {
		bufsArrs[i] = device.GetMessageBuffer()
		bufs[i] = bufsArrs[i][:]
	}

	// Return the buffers this routine still owns.
	defer func() {
		for i := 0; i < maxBatchSize; i++ {
			if bufsArrs[i] != nil {
				device.PutMessageBuffer(bufsArrs[i])
			}
		}
	}()

	for {
		count, err = recv(bufs, sizes, endpoints)
		if err != nil {
			if errors.Is(err, net.ErrClosed) {
				return
			}
			device.log.Verbosef("Failed to receive %s packet: %v", recvName, err)
			if neterr, ok := err.(net.Error); ok && !neterr.Temporary() {
				return
			}
			// Retry transient failures a bounded number of times.
			if deathSpiral < 10 {
				deathSpiral++
				time.Sleep(time.Second / 3)
				continue
			}
			return
		}
		deathSpiral = 0

		// Held for the whole batch: the advanced-security settings consulted
		// below must not change while packets are being classified.
		device.aSecMux.RLock()
		// handle each packet in the batch
		for i, size := range sizes[:count] {
			if size < MinMessageSize {
				continue
			}

			// determine the message type (and strip any junk prefix)

			packet := bufsArrs[i][:size]
			var msgType uint32
			if device.isAdvancedSecurityOn() {
				// With advanced security on, some message types carry a junk
				// prefix whose length depends on the type. The datagram size
				// is used to guess the type and locate the real header.
				// Sizes matching no known type are accepted only as transport
				// messages.
				if assumedMsgType, ok := packetSizeToMsgType[size]; ok {
					junkSize := msgTypeToJunkSize[assumedMsgType]
					// transport size can align with other header types;
					// making sure we have the right msgType
					msgType = binary.LittleEndian.Uint32(packet[junkSize : junkSize+4])
					if msgType == assumedMsgType {
						packet = packet[junkSize:]
					} else {
						device.log.Verbosef("Transport packet lined up with another msg type")
						msgType = binary.LittleEndian.Uint32(packet[:4])
					}
				} else {
					msgType = binary.LittleEndian.Uint32(packet[:4])
					if msgType != MessageTransportType {
						device.log.Verbosef("ASec: Received message with unknown type")
						continue
					}
				}
			} else {
				msgType = binary.LittleEndian.Uint32(packet[:4])
			}
			switch msgType {

			// check if transport

			case MessageTransportType:

				// check size

				if len(packet) < MessageTransportSize {
					continue
				}

				// lookup key pair

				receiver := binary.LittleEndian.Uint32(
					packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
				)
				value := device.indexTable.Lookup(receiver)
				keypair := value.keypair
				if keypair == nil {
					continue
				}

				// check keypair expiry

				if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
					continue
				}

				// create work element
				peer := value.peer
				elem := device.GetInboundElement()
				elem.packet = packet
				elem.buffer = bufsArrs[i]
				elem.keypair = keypair
				elem.endpoint = endpoints[i]
				elem.counter = 0

				elemsForPeer, ok := elemsByPeer[peer]
				if !ok {
					elemsForPeer = device.GetInboundElementsContainer()
					// Locked until RoutineDecryption has processed the batch;
					// the sequential receiver waits on this lock.
					elemsForPeer.Lock()
					elemsByPeer[peer] = elemsForPeer
				}
				elemsForPeer.elems = append(elemsForPeer.elems, elem)
				// The element now owns the buffer; refill this slot.
				bufsArrs[i] = device.GetMessageBuffer()
				bufs[i] = bufsArrs[i][:]
				continue

			// otherwise it is a fixed size & handshake related packet

			case MessageInitiationType:
				if len(packet) != MessageInitiationSize {
					continue
				}

			case MessageResponseType:
				if len(packet) != MessageResponseSize {
					continue
				}

			case MessageCookieReplyType:
				if len(packet) != MessageCookieReplySize {
					continue
				}

			default:
				device.log.Verbosef("Received message with unknown type")
				continue
			}

			// Non-blocking hand-off. If the handshake queue is full, the
			// packet is dropped and its buffer is reused by the next recv.
			select {
			case device.queue.handshake.c <- QueueHandshakeElement{
				msgType:  msgType,
				buffer:   bufsArrs[i],
				packet:   packet,
				endpoint: endpoints[i],
			}:
				bufsArrs[i] = device.GetMessageBuffer()
				bufs[i] = bufsArrs[i][:]
			default:
			}
		}
		device.aSecMux.RUnlock()
		// Dispatch the per-peer batches.
		//   - Running peer: the same container goes to the peer's inbound
		//     queue (which keeps order) and to the shared decryption queue.
		//   - Stopped peer: elements and buffers are returned to the pools.
		// The map is emptied for reuse by the next batch.
		for peer, elemsContainer := range elemsByPeer {
			if peer.isRunning.Load() {
				peer.queue.inbound.c <- elemsContainer
				device.queue.decryption.c <- elemsContainer
			} else {
				for _, elem := range elemsContainer.elems {
					device.PutMessageBuffer(elem.buffer)
					device.PutInboundElement(elem)
				}
				device.PutInboundElementsContainer(elemsContainer)
			}
			delete(elemsByPeer, peer)
		}
	}
}
   264  
   265  func (device *Device) RoutineDecryption(id int) {
   266  	var nonce [chacha20poly1305.NonceSize]byte
   267  
   268  	defer device.log.Verbosef("Routine: decryption worker %d - stopped", id)
   269  	device.log.Verbosef("Routine: decryption worker %d - started", id)
   270  
   271  	for elemsContainer := range device.queue.decryption.c {
   272  		for _, elem := range elemsContainer.elems {
   273  			// split message into fields
   274  			counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
   275  			content := elem.packet[MessageTransportOffsetContent:]
   276  
   277  			// decrypt and release to consumer
   278  			var err error
   279  			elem.counter = binary.LittleEndian.Uint64(counter)
   280  			// copy counter to nonce
   281  			binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
   282  			elem.packet, err = elem.keypair.receive.Open(
   283  				content[:0],
   284  				nonce[:],
   285  				content,
   286  				nil,
   287  			)
   288  			if err != nil {
   289  				elem.packet = nil
   290  			}
   291  		}
   292  		elemsContainer.Unlock()
   293  	}
   294  }
   295  
/* Handles incoming packets related to handshake
 *
 * Worker draining device.queue.handshake.c. Each element is processed while
 * device.aSecMux is held for reading.
 *
 * Cookie reply: the reply is decoded. If its receiver index maps to a running
 * peer, it is passed to that peer's cookie generator.
 *
 * Initiation or response:
 *   - MAC1 is checked.
 *   - When the device is under load, a valid MAC2 is also required (a cookie
 *     reply is sent otherwise), and the per-source-IP rate limiter is applied.
 *   - The message is then consumed; timers and the peer endpoint are updated.
 *   - An initiation is answered with a handshake response. A response derives
 *     a session and is followed by a keepalive.
 *
 * Every path ends at the skip label, which releases the read lock and returns
 * the element's buffer to the pool.
 *
 * NOTE(review): aSecMux stays read-locked while SendHandshakeResponse,
 * SendKeepalive and SendHandshakeCookie run. sync.RWMutex does not support
 * recursive read locking. If any of those re-acquire aSecMux.RLock while a
 * writer is waiting, this worker could deadlock — confirm against the send
 * path.
 */
func (device *Device) RoutineHandshake(id int) {
	defer func() {
		device.log.Verbosef("Routine: handshake worker %d - stopped", id)
		// NOTE(review): signals queue.encryption.wg rather than a handshake wait
		// group; presumably because handshake handling can enqueue outbound
		// packets (e.g. keepalives) — confirm against the device setup code.
		device.queue.encryption.wg.Done()
	}()
	device.log.Verbosef("Routine: handshake worker %d - started", id)

	for elem := range device.queue.handshake.c {

		device.aSecMux.RLock()

		// handle cookie fields and ratelimiting

		switch elem.msgType {

		case MessageCookieReplyType:

			// unmarshal packet

			var reply MessageCookieReply
			reader := bytes.NewReader(elem.packet)
			err := binary.Read(reader, binary.LittleEndian, &reply)
			if err != nil {
				device.log.Verbosef("Failed to decode cookie reply")
				goto skip
			}

			// lookup peer from index

			entry := device.indexTable.Lookup(reply.Receiver)

			if entry.peer == nil {
				goto skip
			}

			// consume reply

			if peer := entry.peer; peer.isRunning.Load() {
				device.log.Verbosef(
					"Receiving cookie response from %s",
					elem.endpoint.DstToString(),
				)
				if !peer.cookieGenerator.ConsumeReply(&reply) {
					device.log.Verbosef(
						"Could not decrypt invalid cookie response",
					)
				}
			}

			// a cookie reply needs no further handling
			goto skip

		case MessageInitiationType, MessageResponseType:

			// check mac fields and maybe ratelimit

			if !device.cookieChecker.CheckMAC1(elem.packet) {
				device.log.Verbosef("Received packet with invalid mac1")
				goto skip
			}

			// endpoints destination address is the source of the datagram

			if device.IsUnderLoad() {

				// verify MAC2 field

				if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) {
					device.SendHandshakeCookie(&elem)
					goto skip
				}

				// check ratelimiter

				if !device.rate.limiter.Allow(elem.endpoint.DstIP()) {
					goto skip
				}
			}

		default:
			device.log.Errorf("Invalid packet ended up in the handshake queue")
			goto skip
		}

		// handle handshake initiation/response content

		switch elem.msgType {
		case MessageInitiationType:
			// unmarshal
			var msg MessageInitiation
			reader := bytes.NewReader(elem.packet)
			err := binary.Read(reader, binary.LittleEndian, &msg)
			if err != nil {
				device.log.Errorf("Failed to decode initiation message")
				goto skip
			}

			// consume initiation
			peer := device.ConsumeMessageInitiation(&msg)
			if peer == nil {
				device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString())
				goto skip
			}

			// update timers

			peer.timersAnyAuthenticatedPacketTraversal()
			peer.timersAnyAuthenticatedPacketReceived()

			// update endpoint
			peer.SetEndpointFromPacket(elem.endpoint)

			device.log.Verbosef("%v - Received handshake initiation", peer)
			peer.rxBytes.Add(uint64(len(elem.packet)))

			peer.SendHandshakeResponse()

		case MessageResponseType:

			// unmarshal

			var msg MessageResponse
			reader := bytes.NewReader(elem.packet)
			err := binary.Read(reader, binary.LittleEndian, &msg)
			if err != nil {
				device.log.Errorf("Failed to decode response message")
				goto skip
			}

			// consume response

			peer := device.ConsumeMessageResponse(&msg)
			if peer == nil {
				device.log.Verbosef("Received invalid response message from %s", elem.endpoint.DstToString())
				goto skip
			}

			// update endpoint
			peer.SetEndpointFromPacket(elem.endpoint)

			device.log.Verbosef("%v - Received handshake response", peer)
			peer.rxBytes.Add(uint64(len(elem.packet)))

			// update timers

			peer.timersAnyAuthenticatedPacketTraversal()
			peer.timersAnyAuthenticatedPacketReceived()

			// derive keypair

			err = peer.BeginSymmetricSession()

			if err != nil {
				device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
				goto skip
			}

			peer.timersSessionDerived()
			peer.timersHandshakeComplete()
			peer.SendKeepalive()
		}
		// common exit: release the read lock, then recycle the buffer
	skip:
		device.aSecMux.RUnlock()
		device.PutMessageBuffer(elem.buffer)
	}
}
   463  
   464  func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
   465  	device := peer.device
   466  	defer func() {
   467  		device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer)
   468  		peer.stopping.Done()
   469  	}()
   470  	device.log.Verbosef("%v - Routine: sequential receiver - started", peer)
   471  
   472  	bufs := make([][]byte, 0, maxBatchSize)
   473  
   474  	for elemsContainer := range peer.queue.inbound.c {
   475  		if elemsContainer == nil {
   476  			return
   477  		}
   478  		elemsContainer.Lock()
   479  		validTailPacket := -1
   480  		dataPacketReceived := false
   481  		rxBytesLen := uint64(0)
   482  		for i, elem := range elemsContainer.elems {
   483  			if elem.packet == nil {
   484  				// decryption failed
   485  				continue
   486  			}
   487  
   488  			if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
   489  				continue
   490  			}
   491  
   492  			validTailPacket = i
   493  			if peer.ReceivedWithKeypair(elem.keypair) {
   494  				peer.SetEndpointFromPacket(elem.endpoint)
   495  				peer.timersHandshakeComplete()
   496  				peer.SendStagedPackets()
   497  			}
   498  			rxBytesLen += uint64(len(elem.packet) + MinMessageSize)
   499  
   500  			if len(elem.packet) == 0 {
   501  				device.log.Verbosef("%v - Receiving keepalive packet", peer)
   502  				continue
   503  			}
   504  			dataPacketReceived = true
   505  
   506  			switch elem.packet[0] >> 4 {
   507  			case 4:
   508  				if len(elem.packet) < ipv4.HeaderLen {
   509  					continue
   510  				}
   511  				field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
   512  				length := binary.BigEndian.Uint16(field)
   513  				if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
   514  					continue
   515  				}
   516  				elem.packet = elem.packet[:length]
   517  				src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
   518  				if device.allowedips.Lookup(src) != peer {
   519  					device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
   520  					continue
   521  				}
   522  
   523  			case 6:
   524  				if len(elem.packet) < ipv6.HeaderLen {
   525  					continue
   526  				}
   527  				field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
   528  				length := binary.BigEndian.Uint16(field)
   529  				length += ipv6.HeaderLen
   530  				if int(length) > len(elem.packet) {
   531  					continue
   532  				}
   533  				elem.packet = elem.packet[:length]
   534  				src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
   535  				if device.allowedips.Lookup(src) != peer {
   536  					device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
   537  					continue
   538  				}
   539  
   540  			default:
   541  				device.log.Verbosef(
   542  					"Packet with invalid IP version from %v",
   543  					peer,
   544  				)
   545  				continue
   546  			}
   547  
   548  			bufs = append(
   549  				bufs,
   550  				elem.buffer[:MessageTransportOffsetContent+len(elem.packet)],
   551  			)
   552  		}
   553  
   554  		peer.rxBytes.Add(rxBytesLen)
   555  		if validTailPacket >= 0 {
   556  			peer.SetEndpointFromPacket(elemsContainer.elems[validTailPacket].endpoint)
   557  			peer.keepKeyFreshReceiving()
   558  			peer.timersAnyAuthenticatedPacketTraversal()
   559  			peer.timersAnyAuthenticatedPacketReceived()
   560  		}
   561  		if dataPacketReceived {
   562  			peer.timersDataReceived()
   563  		}
   564  		if len(bufs) > 0 {
   565  			_, err := device.tun.device.Write(bufs, MessageTransportOffsetContent)
   566  			if err != nil && !device.isClosed() {
   567  				device.log.Errorf("Failed to write packets to TUN device: %v", err)
   568  			}
   569  		}
   570  		for _, elem := range elemsContainer.elems {
   571  			device.PutMessageBuffer(elem.buffer)
   572  			device.PutInboundElement(elem)
   573  		}
   574  		bufs = bufs[:0]
   575  		device.PutInboundElementsContainer(elemsContainer)
   576  	}
   577  }