github.com/bugfan/wireguard-go@v0.0.0-20230720020150-a7b2fa340c66/device/receive.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"bytes"
    10  	"encoding/binary"
    11  	"errors"
    12  	"net"
    13  	"sync"
    14  	"sync/atomic"
    15  	"time"
    16  
    17  	"golang.org/x/crypto/chacha20poly1305"
    18  	"golang.org/x/net/ipv4"
    19  	"golang.org/x/net/ipv6"
    20  
    21  	"github.com/bugfan/wireguard-go/conn"
    22  )
    23  
// QueueHandshakeElement is a handshake-type datagram (initiation, response or
// cookie reply) that RoutineReceiveIncoming has size-checked and queued on
// device.queue.handshake.c for a handshake worker. The worker returns buffer
// to the message-buffer pool when it is done with the element.
type QueueHandshakeElement struct {
	msgType  uint32                // message type, read little-endian from the first 4 bytes
	packet   []byte                // the received datagram; a slice of *buffer
	endpoint conn.Endpoint         // endpoint returned by the bind's receive function for this datagram
	buffer   *[MaxMessageSize]byte // pooled buffer backing packet
}
    30  
// QueueInboundElement is a transport data message travelling from
// RoutineReceiveIncoming, through a decryption worker, to the owning peer's
// sequential receiver. The embedded Mutex is locked before the element is
// queued and is unlocked by the decryption worker once packet holds the
// result. The sequential receiver's Lock therefore waits for decryption.
type QueueInboundElement struct {
	sync.Mutex
	buffer   *[MaxMessageSize]byte // pooled buffer holding the received datagram
	packet   []byte                // whole datagram when queued; decrypted payload afterwards (nil if decryption failed)
	counter  uint64                // counter read from the transport header, used as the AEAD nonce
	keypair  *Keypair              // keypair found via the message's receiver index
	endpoint conn.Endpoint         // endpoint the datagram was received from
}
    39  
    40  // clearPointers clears elem fields that contain pointers.
    41  // This makes the garbage collector's life easier and
    42  // avoids accidentally keeping other objects around unnecessarily.
    43  // It also reduces the possible collateral damage from use-after-free bugs.
    44  func (elem *QueueInboundElement) clearPointers() {
    45  	elem.buffer = nil
    46  	elem.packet = nil
    47  	elem.keypair = nil
    48  	elem.endpoint = nil
    49  }
    50  
    51  /* Called when a new authenticated message has been received
    52   *
    53   * NOTE: Not thread safe, but called by sequential receiver!
    54   */
    55  func (peer *Peer) keepKeyFreshReceiving() {
    56  	if peer.timers.sentLastMinuteHandshake.Get() {
    57  		return
    58  	}
    59  	keypair := peer.keypairs.Current()
    60  	if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
    61  		peer.timers.sentLastMinuteHandshake.Set(true)
    62  		peer.SendHandshakeInitiation(false)
    63  	}
    64  }
    65  
    66  /* Receives incoming datagrams for the device
    67   *
    68   * Every time the bind is updated a new routine is started for
    69   * IPv4 and IPv6 (separately)
    70   */
    71  func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
    72  	recvName := recv.PrettyName()
    73  	defer func() {
    74  		device.log.Verbosef("Routine: receive incoming %s - stopped", recvName)
    75  		device.queue.decryption.wg.Done()
    76  		device.queue.handshake.wg.Done()
    77  		device.net.stopping.Done()
    78  	}()
    79  
    80  	device.log.Verbosef("Routine: receive incoming %s - started", recvName)
    81  
    82  	// receive datagrams until conn is closed
    83  
    84  	buffer := device.GetMessageBuffer()
    85  
    86  	var (
    87  		err         error
    88  		size        int
    89  		endpoint    conn.Endpoint
    90  		deathSpiral int
    91  	)
    92  
    93  	for {
    94  		size, endpoint, err = recv(buffer[:])
    95  
    96  		if err != nil {
    97  			device.PutMessageBuffer(buffer)
    98  			if errors.Is(err, net.ErrClosed) {
    99  				return
   100  			}
   101  			device.log.Verbosef("Failed to receive %s packet: %v", recvName, err)
   102  			if neterr, ok := err.(net.Error); ok && !neterr.Temporary() {
   103  				return
   104  			}
   105  			if deathSpiral < 10 {
   106  				deathSpiral++
   107  				time.Sleep(time.Second / 3)
   108  				buffer = device.GetMessageBuffer()
   109  				continue
   110  			}
   111  			return
   112  		}
   113  		deathSpiral = 0
   114  
   115  		if size < MinMessageSize {
   116  			continue
   117  		}
   118  
   119  		// check size of packet
   120  
   121  		packet := buffer[:size]
   122  		msgType := binary.LittleEndian.Uint32(packet[:4])
   123  
   124  		var okay bool
   125  
   126  		switch msgType {
   127  
   128  		// check if transport
   129  
   130  		case MessageTransportType:
   131  
   132  			// check size
   133  
   134  			if len(packet) < MessageTransportSize {
   135  				continue
   136  			}
   137  
   138  			// lookup key pair
   139  
   140  			receiver := binary.LittleEndian.Uint32(
   141  				packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
   142  			)
   143  			value := device.indexTable.Lookup(receiver)
   144  			keypair := value.keypair
   145  			if keypair == nil {
   146  				continue
   147  			}
   148  
   149  			// check keypair expiry
   150  
   151  			if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
   152  				continue
   153  			}
   154  
   155  			// create work element
   156  			peer := value.peer
   157  			elem := device.GetInboundElement()
   158  			elem.packet = packet
   159  			elem.buffer = buffer
   160  			elem.keypair = keypair
   161  			elem.endpoint = endpoint
   162  			elem.counter = 0
   163  			elem.Mutex = sync.Mutex{}
   164  			elem.Lock()
   165  
   166  			// add to decryption queues
   167  			if peer.isRunning.Get() {
   168  				peer.queue.inbound.c <- elem
   169  				device.queue.decryption.c <- elem
   170  				buffer = device.GetMessageBuffer()
   171  			} else {
   172  				device.PutInboundElement(elem)
   173  			}
   174  			continue
   175  
   176  		// otherwise it is a fixed size & handshake related packet
   177  
   178  		case MessageInitiationType:
   179  			okay = len(packet) == MessageInitiationSize
   180  
   181  		case MessageResponseType:
   182  			okay = len(packet) == MessageResponseSize
   183  
   184  		case MessageCookieReplyType:
   185  			okay = len(packet) == MessageCookieReplySize
   186  
   187  		default:
   188  			device.log.Verbosef("Received message with unknown type")
   189  		}
   190  
   191  		if okay {
   192  			select {
   193  			case device.queue.handshake.c <- QueueHandshakeElement{
   194  				msgType:  msgType,
   195  				buffer:   buffer,
   196  				packet:   packet,
   197  				endpoint: endpoint,
   198  			}:
   199  				buffer = device.GetMessageBuffer()
   200  			default:
   201  			}
   202  		}
   203  	}
   204  }
   205  
   206  func (device *Device) RoutineDecryption(id int) {
   207  	var nonce [chacha20poly1305.NonceSize]byte
   208  
   209  	defer device.log.Verbosef("Routine: decryption worker %d - stopped", id)
   210  	device.log.Verbosef("Routine: decryption worker %d - started", id)
   211  
   212  	for elem := range device.queue.decryption.c {
   213  		// split message into fields
   214  		counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
   215  		content := elem.packet[MessageTransportOffsetContent:]
   216  
   217  		// decrypt and release to consumer
   218  		var err error
   219  		elem.counter = binary.LittleEndian.Uint64(counter)
   220  		// copy counter to nonce
   221  		binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
   222  		elem.packet, err = elem.keypair.receive.Open(
   223  			content[:0],
   224  			nonce[:],
   225  			content,
   226  			nil,
   227  		)
   228  		if err != nil {
   229  			elem.packet = nil
   230  		}
   231  		elem.Unlock()
   232  	}
   233  }
   234  
   235  /* Handles incoming packets related to handshake
   236   */
   237  func (device *Device) RoutineHandshake(id int) {
   238  	defer func() {
   239  		device.log.Verbosef("Routine: handshake worker %d - stopped", id)
   240  		device.queue.encryption.wg.Done()
   241  	}()
   242  	device.log.Verbosef("Routine: handshake worker %d - started", id)
   243  
   244  	for elem := range device.queue.handshake.c {
   245  
   246  		// handle cookie fields and ratelimiting
   247  
   248  		switch elem.msgType {
   249  
   250  		case MessageCookieReplyType:
   251  
   252  			// unmarshal packet
   253  
   254  			var reply MessageCookieReply
   255  			reader := bytes.NewReader(elem.packet)
   256  			err := binary.Read(reader, binary.LittleEndian, &reply)
   257  			if err != nil {
   258  				device.log.Verbosef("Failed to decode cookie reply")
   259  				goto skip
   260  			}
   261  
   262  			// lookup peer from index
   263  
   264  			entry := device.indexTable.Lookup(reply.Receiver)
   265  
   266  			if entry.peer == nil {
   267  				goto skip
   268  			}
   269  
   270  			// consume reply
   271  
   272  			if peer := entry.peer; peer.isRunning.Get() {
   273  				device.log.Verbosef("Receiving cookie response from %s", elem.endpoint.DstToString())
   274  				if !peer.cookieGenerator.ConsumeReply(&reply) {
   275  					device.log.Verbosef("Could not decrypt invalid cookie response")
   276  				}
   277  			}
   278  
   279  			goto skip
   280  
   281  		case MessageInitiationType, MessageResponseType:
   282  
   283  			// check mac fields and maybe ratelimit
   284  
   285  			if !device.cookieChecker.CheckMAC1(elem.packet) {
   286  				device.log.Verbosef("Received packet with invalid mac1")
   287  				goto skip
   288  			}
   289  
   290  			// endpoints destination address is the source of the datagram
   291  
   292  			if device.IsUnderLoad() {
   293  
   294  				// verify MAC2 field
   295  
   296  				if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) {
   297  					device.SendHandshakeCookie(&elem)
   298  					goto skip
   299  				}
   300  
   301  				// check ratelimiter
   302  
   303  				if !device.rate.limiter.Allow(elem.endpoint.DstIP()) {
   304  					goto skip
   305  				}
   306  			}
   307  
   308  		default:
   309  			device.log.Errorf("Invalid packet ended up in the handshake queue")
   310  			goto skip
   311  		}
   312  
   313  		// handle handshake initiation/response content
   314  
   315  		switch elem.msgType {
   316  		case MessageInitiationType:
   317  
   318  			// unmarshal
   319  
   320  			var msg MessageInitiation
   321  			reader := bytes.NewReader(elem.packet)
   322  			err := binary.Read(reader, binary.LittleEndian, &msg)
   323  			if err != nil {
   324  				device.log.Errorf("Failed to decode initiation message")
   325  				goto skip
   326  			}
   327  
   328  			// consume initiation
   329  
   330  			peer := device.ConsumeMessageInitiation(&msg)
   331  			if peer == nil {
   332  				device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString())
   333  				goto skip
   334  			}
   335  
   336  			// update timers
   337  
   338  			peer.timersAnyAuthenticatedPacketTraversal()
   339  			peer.timersAnyAuthenticatedPacketReceived()
   340  
   341  			// update endpoint
   342  			peer.SetEndpointFromPacket(elem.endpoint)
   343  
   344  			device.log.Verbosef("%v - Received handshake initiation", peer)
   345  			atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
   346  
   347  			peer.SendHandshakeResponse()
   348  
   349  		case MessageResponseType:
   350  
   351  			// unmarshal
   352  
   353  			var msg MessageResponse
   354  			reader := bytes.NewReader(elem.packet)
   355  			err := binary.Read(reader, binary.LittleEndian, &msg)
   356  			if err != nil {
   357  				device.log.Errorf("Failed to decode response message")
   358  				goto skip
   359  			}
   360  
   361  			// consume response
   362  
   363  			peer := device.ConsumeMessageResponse(&msg)
   364  			if peer == nil {
   365  				device.log.Verbosef("Received invalid response message from %s", elem.endpoint.DstToString())
   366  				goto skip
   367  			}
   368  
   369  			// update endpoint
   370  			peer.SetEndpointFromPacket(elem.endpoint)
   371  
   372  			device.log.Verbosef("%v - Received handshake response", peer)
   373  			atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
   374  
   375  			// update timers
   376  
   377  			peer.timersAnyAuthenticatedPacketTraversal()
   378  			peer.timersAnyAuthenticatedPacketReceived()
   379  
   380  			// derive keypair
   381  
   382  			err = peer.BeginSymmetricSession()
   383  
   384  			if err != nil {
   385  				device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
   386  				goto skip
   387  			}
   388  
   389  			peer.timersSessionDerived()
   390  			peer.timersHandshakeComplete()
   391  			peer.SendKeepalive()
   392  		}
   393  	skip:
   394  		device.PutMessageBuffer(elem.buffer)
   395  	}
   396  }
   397  
   398  func (peer *Peer) RoutineSequentialReceiver() {
   399  	device := peer.device
   400  	defer func() {
   401  		device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer)
   402  		peer.stopping.Done()
   403  	}()
   404  	device.log.Verbosef("%v - Routine: sequential receiver - started", peer)
   405  
   406  	for elem := range peer.queue.inbound.c {
   407  		if elem == nil {
   408  			return
   409  		}
   410  		var err error
   411  		elem.Lock()
   412  		if elem.packet == nil {
   413  			// decryption failed
   414  			goto skip
   415  		}
   416  
   417  		if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
   418  			goto skip
   419  		}
   420  
   421  		peer.SetEndpointFromPacket(elem.endpoint)
   422  		if peer.ReceivedWithKeypair(elem.keypair) {
   423  			peer.timersHandshakeComplete()
   424  			peer.SendStagedPackets()
   425  		}
   426  
   427  		peer.keepKeyFreshReceiving()
   428  		peer.timersAnyAuthenticatedPacketTraversal()
   429  		peer.timersAnyAuthenticatedPacketReceived()
   430  		atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)+MinMessageSize))
   431  
   432  		if len(elem.packet) == 0 {
   433  			device.log.Verbosef("%v - Receiving keepalive packet", peer)
   434  			goto skip
   435  		}
   436  		peer.timersDataReceived()
   437  
   438  		switch elem.packet[0] >> 4 {
   439  		case ipv4.Version:
   440  			if len(elem.packet) < ipv4.HeaderLen {
   441  				goto skip
   442  			}
   443  			field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
   444  			length := binary.BigEndian.Uint16(field)
   445  			if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
   446  				goto skip
   447  			}
   448  			elem.packet = elem.packet[:length]
   449  			src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
   450  			if device.allowedips.Lookup(src) != peer {
   451  				device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
   452  				goto skip
   453  			}
   454  
   455  		case ipv6.Version:
   456  			if len(elem.packet) < ipv6.HeaderLen {
   457  				goto skip
   458  			}
   459  			field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
   460  			length := binary.BigEndian.Uint16(field)
   461  			length += ipv6.HeaderLen
   462  			if int(length) > len(elem.packet) {
   463  				goto skip
   464  			}
   465  			elem.packet = elem.packet[:length]
   466  			src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
   467  			if device.allowedips.Lookup(src) != peer {
   468  				device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
   469  				goto skip
   470  			}
   471  
   472  		default:
   473  			device.log.Verbosef("Packet with invalid IP version from %v", peer)
   474  			goto skip
   475  		}
   476  
   477  		_, err = device.tun.device.Write(elem.buffer[:MessageTransportOffsetContent+len(elem.packet)], MessageTransportOffsetContent)
   478  		if err != nil && !device.isClosed() {
   479  			device.log.Errorf("Failed to write packet to TUN device: %v", err)
   480  		}
   481  		if len(peer.queue.inbound.c) == 0 {
   482  			err = device.tun.device.Flush()
   483  			if err != nil {
   484  				peer.device.log.Errorf("Unable to flush packets: %v", err)
   485  			}
   486  		}
   487  	skip:
   488  		device.PutMessageBuffer(elem.buffer)
   489  		device.PutInboundElement(elem)
   490  	}
   491  }