github.com/liloew/wireguard-go@v0.0.0-20220224014633-9cd745e6f114/device/receive.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"bytes"
    10  	"encoding/binary"
    11  	"errors"
    12  	"net"
    13  	"sync"
    14  	"sync/atomic"
    15  	"time"
    16  
    17  	"golang.org/x/crypto/chacha20poly1305"
    18  	"golang.org/x/net/ipv4"
    19  	"golang.org/x/net/ipv6"
    20  	"github.com/liloew/wireguard-go/conn"
    21  )
    22  
// QueueHandshakeElement carries one handshake-related datagram (initiation,
// response or cookie reply) from RoutineReceiveIncoming to a RoutineHandshake
// worker.
//
//   - msgType:  message type, read little-endian from the first 4 bytes of packet.
//   - packet:   the received datagram; a slice of buffer.
//   - endpoint: endpoint returned by the receive function for this datagram.
//   - buffer:   pooled message buffer backing packet; the handshake worker
//     returns it to the pool with PutMessageBuffer when it is done.
type QueueHandshakeElement struct {
	msgType  uint32
	packet   []byte
	endpoint conn.Endpoint
	buffer   *[MaxMessageSize]byte
}
    29  
// QueueInboundElement is one received transport-data datagram on its way from
// RoutineReceiveIncoming, through a RoutineDecryption worker, to the peer's
// RoutineSequentialReceiver.
//
//   - buffer:   pooled message buffer that packet slices into.
//   - packet:   the whole datagram on arrival; after decryption the plaintext
//     (decrypted in place inside buffer), or nil if decryption failed.
//   - counter:  message counter parsed from the transport header by the
//     decryption worker.
//   - keypair:  keypair found via the receiver index; used to decrypt and for
//     replay filtering.
//   - endpoint: endpoint returned by the receive function for this datagram.
type QueueInboundElement struct {
	// Hand-off lock: the receive loop locks it before queueing the element,
	// the decryption worker unlocks it when finished, and the sequential
	// receiver Locks it to wait for decryption to complete.
	sync.Mutex
	buffer   *[MaxMessageSize]byte
	packet   []byte
	counter  uint64
	keypair  *Keypair
	endpoint conn.Endpoint
}
    38  
    39  // clearPointers clears elem fields that contain pointers.
    40  // This makes the garbage collector's life easier and
    41  // avoids accidentally keeping other objects around unnecessarily.
    42  // It also reduces the possible collateral damage from use-after-free bugs.
    43  func (elem *QueueInboundElement) clearPointers() {
    44  	elem.buffer = nil
    45  	elem.packet = nil
    46  	elem.keypair = nil
    47  	elem.endpoint = nil
    48  }
    49  
    50  /* Called when a new authenticated message has been received
    51   *
    52   * NOTE: Not thread safe, but called by sequential receiver!
    53   */
    54  func (peer *Peer) keepKeyFreshReceiving() {
    55  	if peer.timers.sentLastMinuteHandshake.Get() {
    56  		return
    57  	}
    58  	keypair := peer.keypairs.Current()
    59  	if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
    60  		peer.timers.sentLastMinuteHandshake.Set(true)
    61  		peer.SendHandshakeInitiation(false)
    62  	}
    63  }
    64  
// RoutineReceiveIncoming reads datagrams from recv and dispatches them to the
// transport (decryption) and handshake pipelines.
//
// Per the original design note, every time the bind is updated a new routine
// is started for each receive function (IPv4 and IPv6 separately).
//
// The loop runs until recv fails in one of three ways:
//   - with net.ErrClosed;
//   - with a net.Error whose Temporary() reports false;
//   - with an 11th consecutive error. Up to 10 retries are made, each after
//     sleeping a third of a second, and any successful read resets the count.
//
// On exit the routine calls Done on the decryption and handshake queue
// WaitGroups and on device.net.stopping.
//
// Buffer ownership: packet always aliases buffer. Whenever buffer is handed to
// a queue, a fresh one is taken from the pool. When a datagram is dropped,
// buffer stays with this loop and is reused for the next read.
func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
	recvName := recv.PrettyName()
	defer func() {
		device.log.Verbosef("Routine: receive incoming %s - stopped", recvName)
		device.queue.decryption.wg.Done()
		device.queue.handshake.wg.Done()
		device.net.stopping.Done()
	}()

	device.log.Verbosef("Routine: receive incoming %s - started", recvName)

	// receive datagrams until conn is closed

	buffer := device.GetMessageBuffer()

	var (
		err         error
		size        int
		endpoint    conn.Endpoint
		deathSpiral int // number of consecutive receive errors retried so far
	)

	for {
		size, endpoint, err = recv(buffer[:])

		if err != nil {
			// The buffer is released on every error path; it is
			// re-acquired only if the read is retried.
			device.PutMessageBuffer(buffer)
			if errors.Is(err, net.ErrClosed) {
				return
			}
			device.log.Verbosef("Failed to receive %s packet: %v", recvName, err)
			// NOTE(review): net.Error.Temporary is deprecated since Go 1.18;
			// kept as-is to preserve the existing retry policy.
			if neterr, ok := err.(net.Error); ok && !neterr.Temporary() {
				return
			}
			if deathSpiral < 10 {
				deathSpiral++
				time.Sleep(time.Second / 3)
				buffer = device.GetMessageBuffer()
				continue
			}
			return
		}
		deathSpiral = 0

		// Ignore datagrams shorter than MinMessageSize.
		// NOTE(review): presumably this also guarantees the 4-byte type read
		// below is in bounds; confirm MinMessageSize >= 4.
		if size < MinMessageSize {
			continue
		}

		// check size of packet

		packet := buffer[:size]
		msgType := binary.LittleEndian.Uint32(packet[:4])

		// okay is set only by the handshake cases, when the datagram has the
		// exact fixed size expected for its message type.
		var okay bool

		switch msgType {

		// check if transport

		case MessageTransportType:

			// check size

			if len(packet) < MessageTransportSize {
				continue
			}

			// lookup key pair

			receiver := binary.LittleEndian.Uint32(
				packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
			)
			value := device.indexTable.Lookup(receiver)
			keypair := value.keypair
			if keypair == nil {
				continue
			}

			// check keypair expiry: drop packets for keypairs created more
			// than RejectAfterTime ago

			if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
				continue
			}

			// create work element
			peer := value.peer
			elem := device.GetInboundElement()
			elem.packet = packet
			elem.buffer = buffer
			elem.keypair = keypair
			elem.endpoint = endpoint
			elem.counter = 0
			elem.Mutex = sync.Mutex{}
			// Lock before publishing. A decryption worker unlocks the element
			// when it is done; the peer's sequential receiver blocks on Lock
			// until then.
			elem.Lock()

			// add to decryption queues
			//
			// The peer's inbound queue fixes the order in which packets are
			// delivered; the shared decryption queue lets workers decrypt in
			// parallel. Both are blocking channel sends. The element now owns
			// buffer, so a fresh one is taken for the next read.
			if peer.isRunning.Get() {
				peer.queue.inbound.c <- elem
				device.queue.decryption.c <- elem
				buffer = device.GetMessageBuffer()
			} else {
				// Peer not running: recycle the element and keep buffer.
				device.PutInboundElement(elem)
			}
			continue

		// otherwise it is a fixed size & handshake related packet

		case MessageInitiationType:
			okay = len(packet) == MessageInitiationSize

		case MessageResponseType:
			okay = len(packet) == MessageResponseSize

		case MessageCookieReplyType:
			okay = len(packet) == MessageCookieReplySize

		default:
			device.log.Verbosef("Received message with unknown type")
		}

		if okay {
			// Non-blocking send. If the handshake queue is full, the datagram
			// is dropped and buffer is reused for the next read.
			select {
			case device.queue.handshake.c <- QueueHandshakeElement{
				msgType:  msgType,
				buffer:   buffer,
				packet:   packet,
				endpoint: endpoint,
			}:
				buffer = device.GetMessageBuffer()
			default:
			}
		}
	}
}
   204  
   205  func (device *Device) RoutineDecryption(id int) {
   206  	var nonce [chacha20poly1305.NonceSize]byte
   207  
   208  	defer device.log.Verbosef("Routine: decryption worker %d - stopped", id)
   209  	device.log.Verbosef("Routine: decryption worker %d - started", id)
   210  
   211  	for elem := range device.queue.decryption.c {
   212  		// split message into fields
   213  		counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
   214  		content := elem.packet[MessageTransportOffsetContent:]
   215  
   216  		// decrypt and release to consumer
   217  		var err error
   218  		elem.counter = binary.LittleEndian.Uint64(counter)
   219  		// copy counter to nonce
   220  		binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
   221  		elem.packet, err = elem.keypair.receive.Open(
   222  			content[:0],
   223  			nonce[:],
   224  			content,
   225  			nil,
   226  		)
   227  		if err != nil {
   228  			elem.packet = nil
   229  		}
   230  		elem.Unlock()
   231  	}
   232  }
   233  
// RoutineHandshake is a handshake worker. It consumes elements from
// device.queue.handshake.c until that channel is closed, processing cookie
// replies, handshake initiations and handshake responses.
//
// Processing happens in two stages:
//   - The first switch handles cookie replies completely and applies the
//     anti-DoS checks to initiations and responses. MAC1 must be valid. When
//     the device is under load, MAC2 must also be valid (otherwise a cookie
//     reply is sent back) and the source IP must pass the rate limiter.
//   - The second switch decodes and consumes the initiation or response
//     itself.
//
// Every path, successful or not, ends at the skip label, which returns the
// element's buffer to the message-buffer pool.
//
// NOTE(review): on exit this worker calls Done on device.queue.encryption.wg,
// not on the handshake queue's WaitGroup. Presumably this is because it can
// emit outbound packets (SendHandshakeResponse, SendKeepalive); confirm
// against the code that Adds to that WaitGroup.
func (device *Device) RoutineHandshake(id int) {
	defer func() {
		device.log.Verbosef("Routine: handshake worker %d - stopped", id)
		device.queue.encryption.wg.Done()
	}()
	device.log.Verbosef("Routine: handshake worker %d - started", id)

	for elem := range device.queue.handshake.c {

		// handle cookie fields and ratelimiting

		switch elem.msgType {

		case MessageCookieReplyType:

			// unmarshal packet

			var reply MessageCookieReply
			reader := bytes.NewReader(elem.packet)
			err := binary.Read(reader, binary.LittleEndian, &reply)
			if err != nil {
				device.log.Verbosef("Failed to decode cookie reply")
				goto skip
			}

			// lookup peer from index

			entry := device.indexTable.Lookup(reply.Receiver)

			if entry.peer == nil {
				goto skip
			}

			// consume reply

			if peer := entry.peer; peer.isRunning.Get() {
				device.log.Verbosef("Receiving cookie response from %s", elem.endpoint.DstToString())
				if !peer.cookieGenerator.ConsumeReply(&reply) {
					device.log.Verbosef("Could not decrypt invalid cookie response")
				}
			}

			// cookie replies are fully handled by this stage
			goto skip

		case MessageInitiationType, MessageResponseType:

			// check mac fields and maybe ratelimit

			if !device.cookieChecker.CheckMAC1(elem.packet) {
				device.log.Verbosef("Received packet with invalid mac1")
				goto skip
			}

			// The endpoint's destination address is the source address of
			// the received datagram; MAC2 and the rate limiter key on it.

			if device.IsUnderLoad() {

				// verify MAC2 field; if it is invalid, answer with a cookie
				// reply instead of processing the handshake

				if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) {
					device.SendHandshakeCookie(&elem)
					goto skip
				}

				// check ratelimiter

				if !device.rate.limiter.Allow(elem.endpoint.DstIP()) {
					goto skip
				}
			}

		default:
			device.log.Errorf("Invalid packet ended up in the handshake queue")
			goto skip
		}

		// handle handshake initiation/response content

		switch elem.msgType {
		case MessageInitiationType:

			// unmarshal

			var msg MessageInitiation
			reader := bytes.NewReader(elem.packet)
			err := binary.Read(reader, binary.LittleEndian, &msg)
			if err != nil {
				device.log.Errorf("Failed to decode initiation message")
				goto skip
			}

			// consume initiation

			peer := device.ConsumeMessageInitiation(&msg)
			if peer == nil {
				device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString())
				goto skip
			}

			// update timers

			peer.timersAnyAuthenticatedPacketTraversal()
			peer.timersAnyAuthenticatedPacketReceived()

			// update endpoint
			peer.SetEndpointFromPacket(elem.endpoint)

			device.log.Verbosef("%v - Received handshake initiation", peer)
			atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))

			peer.SendHandshakeResponse()

		case MessageResponseType:

			// unmarshal

			var msg MessageResponse
			reader := bytes.NewReader(elem.packet)
			err := binary.Read(reader, binary.LittleEndian, &msg)
			if err != nil {
				device.log.Errorf("Failed to decode response message")
				goto skip
			}

			// consume response

			peer := device.ConsumeMessageResponse(&msg)
			if peer == nil {
				device.log.Verbosef("Received invalid response message from %s", elem.endpoint.DstToString())
				goto skip
			}

			// update endpoint
			peer.SetEndpointFromPacket(elem.endpoint)

			device.log.Verbosef("%v - Received handshake response", peer)
			atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))

			// update timers

			peer.timersAnyAuthenticatedPacketTraversal()
			peer.timersAnyAuthenticatedPacketReceived()

			// derive keypair

			err = peer.BeginSymmetricSession()

			if err != nil {
				device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
				goto skip
			}

			peer.timersSessionDerived()
			peer.timersHandshakeComplete()
			peer.SendKeepalive()
		}
		// every path converges here so the buffer is always recycled
	skip:
		device.PutMessageBuffer(elem.buffer)
	}
}
   396  
   397  func (peer *Peer) RoutineSequentialReceiver() {
   398  	device := peer.device
   399  	defer func() {
   400  		device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer)
   401  		peer.stopping.Done()
   402  	}()
   403  	device.log.Verbosef("%v - Routine: sequential receiver - started", peer)
   404  
   405  	for elem := range peer.queue.inbound.c {
   406  		if elem == nil {
   407  			return
   408  		}
   409  		var err error
   410  		elem.Lock()
   411  		if elem.packet == nil {
   412  			// decryption failed
   413  			goto skip
   414  		}
   415  
   416  		if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
   417  			goto skip
   418  		}
   419  
   420  		peer.SetEndpointFromPacket(elem.endpoint)
   421  		if peer.ReceivedWithKeypair(elem.keypair) {
   422  			peer.timersHandshakeComplete()
   423  			peer.SendStagedPackets()
   424  		}
   425  
   426  		peer.keepKeyFreshReceiving()
   427  		peer.timersAnyAuthenticatedPacketTraversal()
   428  		peer.timersAnyAuthenticatedPacketReceived()
   429  		atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)+MinMessageSize))
   430  
   431  		if len(elem.packet) == 0 {
   432  			device.log.Verbosef("%v - Receiving keepalive packet", peer)
   433  			goto skip
   434  		}
   435  		peer.timersDataReceived()
   436  
   437  		switch elem.packet[0] >> 4 {
   438  		case ipv4.Version:
   439  			if len(elem.packet) < ipv4.HeaderLen {
   440  				goto skip
   441  			}
   442  			field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
   443  			length := binary.BigEndian.Uint16(field)
   444  			if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
   445  				goto skip
   446  			}
   447  			elem.packet = elem.packet[:length]
   448  			src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
   449  			if device.allowedips.Lookup(src) != peer {
   450  				device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
   451  				goto skip
   452  			}
   453  
   454  		case ipv6.Version:
   455  			if len(elem.packet) < ipv6.HeaderLen {
   456  				goto skip
   457  			}
   458  			field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
   459  			length := binary.BigEndian.Uint16(field)
   460  			length += ipv6.HeaderLen
   461  			if int(length) > len(elem.packet) {
   462  				goto skip
   463  			}
   464  			elem.packet = elem.packet[:length]
   465  			src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
   466  			if device.allowedips.Lookup(src) != peer {
   467  				device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
   468  				goto skip
   469  			}
   470  
   471  		default:
   472  			device.log.Verbosef("Packet with invalid IP version from %v", peer)
   473  			goto skip
   474  		}
   475  
   476  		_, err = device.tun.device.Write(elem.buffer[:MessageTransportOffsetContent+len(elem.packet)], MessageTransportOffsetContent)
   477  		if err != nil && !device.isClosed() {
   478  			device.log.Errorf("Failed to write packet to TUN device: %v", err)
   479  		}
   480  		if len(peer.queue.inbound.c) == 0 {
   481  			err = device.tun.device.Flush()
   482  			if err != nil {
   483  				peer.device.log.Errorf("Unable to flush packets: %v", err)
   484  			}
   485  		}
   486  	skip:
   487  		device.PutMessageBuffer(elem.buffer)
   488  		device.PutInboundElement(elem)
   489  	}
   490  }