github.com/koomox/wireguard-go@v0.0.0-20230722134753-17a50b2f22a3/device/receive.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"bytes"
    10  	"encoding/binary"
    11  	"errors"
    12  	"net"
    13  	"sync"
    14  	"time"
    15  
    16  	"golang.org/x/crypto/chacha20poly1305"
    17  	"golang.org/x/net/ipv4"
    18  	"golang.org/x/net/ipv6"
    19  	"github.com/koomox/wireguard-go/conn"
    20  )
    21  
    22  type QueueHandshakeElement struct {
    23  	msgType  uint32
    24  	packet   []byte
    25  	endpoint conn.Endpoint
    26  	buffer   *[MaxMessageSize]byte
    27  }
    28  
    29  type QueueInboundElement struct {
    30  	sync.Mutex
    31  	buffer   *[MaxMessageSize]byte
    32  	packet   []byte
    33  	counter  uint64
    34  	keypair  *Keypair
    35  	endpoint conn.Endpoint
    36  }
    37  
    38  // clearPointers clears elem fields that contain pointers.
    39  // This makes the garbage collector's life easier and
    40  // avoids accidentally keeping other objects around unnecessarily.
    41  // It also reduces the possible collateral damage from use-after-free bugs.
    42  func (elem *QueueInboundElement) clearPointers() {
    43  	elem.buffer = nil
    44  	elem.packet = nil
    45  	elem.keypair = nil
    46  	elem.endpoint = nil
    47  }
    48  
    49  /* Called when a new authenticated message has been received
    50   *
    51   * NOTE: Not thread safe, but called by sequential receiver!
    52   */
    53  func (peer *Peer) keepKeyFreshReceiving() {
    54  	if peer.timers.sentLastMinuteHandshake.Load() {
    55  		return
    56  	}
    57  	keypair := peer.keypairs.Current()
    58  	if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
    59  		peer.timers.sentLastMinuteHandshake.Store(true)
    60  		peer.SendHandshakeInitiation(false)
    61  	}
    62  }
    63  
    64  /* Receives incoming datagrams for the device
    65   *
    66   * Every time the bind is updated a new routine is started for
    67   * IPv4 and IPv6 (separately)
    68   */
    69  func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.ReceiveFunc) {
    70  	recvName := recv.PrettyName()
    71  	defer func() {
    72  		device.log.Verbosef("Routine: receive incoming %s - stopped", recvName)
    73  		device.queue.decryption.wg.Done()
    74  		device.queue.handshake.wg.Done()
    75  		device.net.stopping.Done()
    76  	}()
    77  
    78  	device.log.Verbosef("Routine: receive incoming %s - started", recvName)
    79  
    80  	// receive datagrams until conn is closed
    81  
    82  	var (
    83  		bufsArrs    = make([]*[MaxMessageSize]byte, maxBatchSize)
    84  		bufs        = make([][]byte, maxBatchSize)
    85  		err         error
    86  		sizes       = make([]int, maxBatchSize)
    87  		count       int
    88  		endpoints   = make([]conn.Endpoint, maxBatchSize)
    89  		deathSpiral int
    90  		elemsByPeer = make(map[*Peer]*[]*QueueInboundElement, maxBatchSize)
    91  	)
    92  
    93  	for i := range bufsArrs {
    94  		bufsArrs[i] = device.GetMessageBuffer()
    95  		bufs[i] = bufsArrs[i][:]
    96  	}
    97  
    98  	defer func() {
    99  		for i := 0; i < maxBatchSize; i++ {
   100  			if bufsArrs[i] != nil {
   101  				device.PutMessageBuffer(bufsArrs[i])
   102  			}
   103  		}
   104  	}()
   105  
   106  	for {
   107  		count, err = recv(bufs, sizes, endpoints)
   108  		if err != nil {
   109  			if errors.Is(err, net.ErrClosed) {
   110  				return
   111  			}
   112  			device.log.Verbosef("Failed to receive %s packet: %v", recvName, err)
   113  			if neterr, ok := err.(net.Error); ok && !neterr.Temporary() {
   114  				return
   115  			}
   116  			if deathSpiral < 10 {
   117  				deathSpiral++
   118  				time.Sleep(time.Second / 3)
   119  				continue
   120  			}
   121  			return
   122  		}
   123  		deathSpiral = 0
   124  
   125  		// handle each packet in the batch
   126  		for i, size := range sizes[:count] {
   127  			if size < MinMessageSize {
   128  				continue
   129  			}
   130  
   131  			// check size of packet
   132  
   133  			packet := bufsArrs[i][:size]
   134  			msgType := binary.LittleEndian.Uint32(packet[:4])
   135  
   136  			switch msgType {
   137  
   138  			// check if transport
   139  
   140  			case MessageTransportType:
   141  
   142  				// check size
   143  
   144  				if len(packet) < MessageTransportSize {
   145  					continue
   146  				}
   147  
   148  				// lookup key pair
   149  
   150  				receiver := binary.LittleEndian.Uint32(
   151  					packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
   152  				)
   153  				value := device.indexTable.Lookup(receiver)
   154  				keypair := value.keypair
   155  				if keypair == nil {
   156  					continue
   157  				}
   158  
   159  				// check keypair expiry
   160  
   161  				if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
   162  					continue
   163  				}
   164  
   165  				// create work element
   166  				peer := value.peer
   167  				elem := device.GetInboundElement()
   168  				elem.packet = packet
   169  				elem.buffer = bufsArrs[i]
   170  				elem.keypair = keypair
   171  				elem.endpoint = endpoints[i]
   172  				elem.counter = 0
   173  				elem.Mutex = sync.Mutex{}
   174  				elem.Lock()
   175  
   176  				elemsForPeer, ok := elemsByPeer[peer]
   177  				if !ok {
   178  					elemsForPeer = device.GetInboundElementsSlice()
   179  					elemsByPeer[peer] = elemsForPeer
   180  				}
   181  				*elemsForPeer = append(*elemsForPeer, elem)
   182  				bufsArrs[i] = device.GetMessageBuffer()
   183  				bufs[i] = bufsArrs[i][:]
   184  				continue
   185  
   186  			// otherwise it is a fixed size & handshake related packet
   187  
   188  			case MessageInitiationType:
   189  				if len(packet) != MessageInitiationSize {
   190  					continue
   191  				}
   192  
   193  			case MessageResponseType:
   194  				if len(packet) != MessageResponseSize {
   195  					continue
   196  				}
   197  
   198  			case MessageCookieReplyType:
   199  				if len(packet) != MessageCookieReplySize {
   200  					continue
   201  				}
   202  
   203  			default:
   204  				device.log.Verbosef("Received message with unknown type")
   205  				continue
   206  			}
   207  
   208  			select {
   209  			case device.queue.handshake.c <- QueueHandshakeElement{
   210  				msgType:  msgType,
   211  				buffer:   bufsArrs[i],
   212  				packet:   packet,
   213  				endpoint: endpoints[i],
   214  			}:
   215  				bufsArrs[i] = device.GetMessageBuffer()
   216  				bufs[i] = bufsArrs[i][:]
   217  			default:
   218  			}
   219  		}
   220  		for peer, elems := range elemsByPeer {
   221  			if peer.isRunning.Load() {
   222  				peer.queue.inbound.c <- elems
   223  				for _, elem := range *elems {
   224  					device.queue.decryption.c <- elem
   225  				}
   226  			} else {
   227  				for _, elem := range *elems {
   228  					device.PutMessageBuffer(elem.buffer)
   229  					device.PutInboundElement(elem)
   230  				}
   231  				device.PutInboundElementsSlice(elems)
   232  			}
   233  			delete(elemsByPeer, peer)
   234  		}
   235  	}
   236  }
   237  
   238  func (device *Device) RoutineDecryption(id int) {
   239  	var nonce [chacha20poly1305.NonceSize]byte
   240  
   241  	defer device.log.Verbosef("Routine: decryption worker %d - stopped", id)
   242  	device.log.Verbosef("Routine: decryption worker %d - started", id)
   243  
   244  	for elem := range device.queue.decryption.c {
   245  		// split message into fields
   246  		counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
   247  		content := elem.packet[MessageTransportOffsetContent:]
   248  
   249  		// decrypt and release to consumer
   250  		var err error
   251  		elem.counter = binary.LittleEndian.Uint64(counter)
   252  		// copy counter to nonce
   253  		binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
   254  		elem.packet, err = elem.keypair.receive.Open(
   255  			content[:0],
   256  			nonce[:],
   257  			content,
   258  			nil,
   259  		)
   260  		if err != nil {
   261  			elem.packet = nil
   262  		}
   263  		elem.Unlock()
   264  	}
   265  }
   266  
   267  /* Handles incoming packets related to handshake
   268   */
   269  func (device *Device) RoutineHandshake(id int) {
   270  	defer func() {
   271  		device.log.Verbosef("Routine: handshake worker %d - stopped", id)
   272  		device.queue.encryption.wg.Done()
   273  	}()
   274  	device.log.Verbosef("Routine: handshake worker %d - started", id)
   275  
   276  	for elem := range device.queue.handshake.c {
   277  
   278  		// handle cookie fields and ratelimiting
   279  
   280  		switch elem.msgType {
   281  
   282  		case MessageCookieReplyType:
   283  
   284  			// unmarshal packet
   285  
   286  			var reply MessageCookieReply
   287  			reader := bytes.NewReader(elem.packet)
   288  			err := binary.Read(reader, binary.LittleEndian, &reply)
   289  			if err != nil {
   290  				device.log.Verbosef("Failed to decode cookie reply")
   291  				goto skip
   292  			}
   293  
   294  			// lookup peer from index
   295  
   296  			entry := device.indexTable.Lookup(reply.Receiver)
   297  
   298  			if entry.peer == nil {
   299  				goto skip
   300  			}
   301  
   302  			// consume reply
   303  
   304  			if peer := entry.peer; peer.isRunning.Load() {
   305  				device.log.Verbosef("Receiving cookie response from %s", elem.endpoint.DstToString())
   306  				if !peer.cookieGenerator.ConsumeReply(&reply) {
   307  					device.log.Verbosef("Could not decrypt invalid cookie response")
   308  				}
   309  			}
   310  
   311  			goto skip
   312  
   313  		case MessageInitiationType, MessageResponseType:
   314  
   315  			// check mac fields and maybe ratelimit
   316  
   317  			if !device.cookieChecker.CheckMAC1(elem.packet) {
   318  				device.log.Verbosef("Received packet with invalid mac1")
   319  				goto skip
   320  			}
   321  
   322  			// endpoints destination address is the source of the datagram
   323  
   324  			if device.IsUnderLoad() {
   325  
   326  				// verify MAC2 field
   327  
   328  				if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) {
   329  					device.SendHandshakeCookie(&elem)
   330  					goto skip
   331  				}
   332  
   333  				// check ratelimiter
   334  
   335  				if !device.rate.limiter.Allow(elem.endpoint.DstIP()) {
   336  					goto skip
   337  				}
   338  			}
   339  
   340  		default:
   341  			device.log.Errorf("Invalid packet ended up in the handshake queue")
   342  			goto skip
   343  		}
   344  
   345  		// handle handshake initiation/response content
   346  
   347  		switch elem.msgType {
   348  		case MessageInitiationType:
   349  
   350  			// unmarshal
   351  
   352  			var msg MessageInitiation
   353  			reader := bytes.NewReader(elem.packet)
   354  			err := binary.Read(reader, binary.LittleEndian, &msg)
   355  			if err != nil {
   356  				device.log.Errorf("Failed to decode initiation message")
   357  				goto skip
   358  			}
   359  
   360  			// consume initiation
   361  
   362  			peer := device.ConsumeMessageInitiation(&msg)
   363  			if peer == nil {
   364  				device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString())
   365  				goto skip
   366  			}
   367  
   368  			// update timers
   369  
   370  			peer.timersAnyAuthenticatedPacketTraversal()
   371  			peer.timersAnyAuthenticatedPacketReceived()
   372  
   373  			// update endpoint
   374  			peer.SetEndpointFromPacket(elem.endpoint)
   375  
   376  			device.log.Verbosef("%v - Received handshake initiation", peer)
   377  			peer.rxBytes.Add(uint64(len(elem.packet)))
   378  
   379  			peer.SendHandshakeResponse()
   380  
   381  		case MessageResponseType:
   382  
   383  			// unmarshal
   384  
   385  			var msg MessageResponse
   386  			reader := bytes.NewReader(elem.packet)
   387  			err := binary.Read(reader, binary.LittleEndian, &msg)
   388  			if err != nil {
   389  				device.log.Errorf("Failed to decode response message")
   390  				goto skip
   391  			}
   392  
   393  			// consume response
   394  
   395  			peer := device.ConsumeMessageResponse(&msg)
   396  			if peer == nil {
   397  				device.log.Verbosef("Received invalid response message from %s", elem.endpoint.DstToString())
   398  				goto skip
   399  			}
   400  
   401  			// update endpoint
   402  			peer.SetEndpointFromPacket(elem.endpoint)
   403  
   404  			device.log.Verbosef("%v - Received handshake response", peer)
   405  			peer.rxBytes.Add(uint64(len(elem.packet)))
   406  
   407  			// update timers
   408  
   409  			peer.timersAnyAuthenticatedPacketTraversal()
   410  			peer.timersAnyAuthenticatedPacketReceived()
   411  
   412  			// derive keypair
   413  
   414  			err = peer.BeginSymmetricSession()
   415  
   416  			if err != nil {
   417  				device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
   418  				goto skip
   419  			}
   420  
   421  			peer.timersSessionDerived()
   422  			peer.timersHandshakeComplete()
   423  			peer.SendKeepalive()
   424  		}
   425  	skip:
   426  		device.PutMessageBuffer(elem.buffer)
   427  	}
   428  }
   429  
   430  func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
   431  	device := peer.device
   432  	defer func() {
   433  		device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer)
   434  		peer.stopping.Done()
   435  	}()
   436  	device.log.Verbosef("%v - Routine: sequential receiver - started", peer)
   437  
   438  	bufs := make([][]byte, 0, maxBatchSize)
   439  
   440  	for elems := range peer.queue.inbound.c {
   441  		if elems == nil {
   442  			return
   443  		}
   444  		for _, elem := range *elems {
   445  			elem.Lock()
   446  			if elem.packet == nil {
   447  				// decryption failed
   448  				continue
   449  			}
   450  
   451  			if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
   452  				continue
   453  			}
   454  
   455  			peer.SetEndpointFromPacket(elem.endpoint)
   456  			if peer.ReceivedWithKeypair(elem.keypair) {
   457  				peer.timersHandshakeComplete()
   458  				peer.SendStagedPackets()
   459  			}
   460  			peer.keepKeyFreshReceiving()
   461  			peer.timersAnyAuthenticatedPacketTraversal()
   462  			peer.timersAnyAuthenticatedPacketReceived()
   463  			peer.rxBytes.Add(uint64(len(elem.packet) + MinMessageSize))
   464  
   465  			if len(elem.packet) == 0 {
   466  				device.log.Verbosef("%v - Receiving keepalive packet", peer)
   467  				continue
   468  			}
   469  			peer.timersDataReceived()
   470  
   471  			switch elem.packet[0] >> 4 {
   472  			case 4:
   473  				if len(elem.packet) < ipv4.HeaderLen {
   474  					continue
   475  				}
   476  				field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
   477  				length := binary.BigEndian.Uint16(field)
   478  				if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
   479  					continue
   480  				}
   481  				elem.packet = elem.packet[:length]
   482  				src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
   483  				if device.allowedips.Lookup(src) != peer {
   484  					device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
   485  					continue
   486  				}
   487  
   488  			case 6:
   489  				if len(elem.packet) < ipv6.HeaderLen {
   490  					continue
   491  				}
   492  				field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
   493  				length := binary.BigEndian.Uint16(field)
   494  				length += ipv6.HeaderLen
   495  				if int(length) > len(elem.packet) {
   496  					continue
   497  				}
   498  				elem.packet = elem.packet[:length]
   499  				src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
   500  				if device.allowedips.Lookup(src) != peer {
   501  					device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
   502  					continue
   503  				}
   504  
   505  			default:
   506  				device.log.Verbosef("Packet with invalid IP version from %v", peer)
   507  				continue
   508  			}
   509  
   510  			bufs = append(bufs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)])
   511  		}
   512  		if len(bufs) > 0 {
   513  			_, err := device.tun.device.Write(bufs, MessageTransportOffsetContent)
   514  			if err != nil && !device.isClosed() {
   515  				device.log.Errorf("Failed to write packets to TUN device: %v", err)
   516  			}
   517  		}
   518  		for _, elem := range *elems {
   519  			device.PutMessageBuffer(elem.buffer)
   520  			device.PutInboundElement(elem)
   521  		}
   522  		bufs = bufs[:0]
   523  		device.PutInboundElementsSlice(elems)
   524  	}
   525  }