github.com/GFW-knocker/wireguard@v1.0.1/device/receive.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"bytes"
    10  	"encoding/binary"
    11  	"errors"
    12  	"net"
    13  	"sync"
    14  	"time"
    15  
    16  	"github.com/GFW-knocker/wireguard/conn"
    17  	"golang.org/x/crypto/chacha20poly1305"
    18  	"golang.org/x/net/ipv4"
    19  	"golang.org/x/net/ipv6"
    20  )
    21  
    22  type QueueHandshakeElement struct {
    23  	msgType  uint32
    24  	packet   []byte
    25  	endpoint conn.Endpoint
    26  	buffer   *[MaxMessageSize]byte
    27  }
    28  
    29  type QueueInboundElement struct {
    30  	buffer   *[MaxMessageSize]byte
    31  	packet   []byte
    32  	counter  uint64
    33  	keypair  *Keypair
    34  	endpoint conn.Endpoint
    35  }
    36  
    37  type QueueInboundElementsContainer struct {
    38  	sync.Mutex
    39  	elems []*QueueInboundElement
    40  }
    41  
    42  // clearPointers clears elem fields that contain pointers.
    43  // This makes the garbage collector's life easier and
    44  // avoids accidentally keeping other objects around unnecessarily.
    45  // It also reduces the possible collateral damage from use-after-free bugs.
    46  func (elem *QueueInboundElement) clearPointers() {
    47  	elem.buffer = nil
    48  	elem.packet = nil
    49  	elem.keypair = nil
    50  	elem.endpoint = nil
    51  }
    52  
    53  /* Called when a new authenticated message has been received
    54   *
    55   * NOTE: Not thread safe, but called by sequential receiver!
    56   */
    57  func (peer *Peer) keepKeyFreshReceiving() {
    58  	if peer.timers.sentLastMinuteHandshake.Load() {
    59  		return
    60  	}
    61  	keypair := peer.keypairs.Current()
    62  	if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
    63  		peer.timers.sentLastMinuteHandshake.Store(true)
    64  		peer.SendHandshakeInitiation(false)
    65  	}
    66  }
    67  
    68  /* Receives incoming datagrams for the device
    69   *
    70   * Every time the bind is updated a new routine is started for
    71   * IPv4 and IPv6 (separately)
    72   */
    73  func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.ReceiveFunc) {
    74  	recvName := recv.PrettyName()
    75  	defer func() {
    76  		device.log.Verbosef("Routine: receive incoming %s - stopped", recvName)
    77  		device.queue.decryption.wg.Done()
    78  		device.queue.handshake.wg.Done()
    79  		device.net.stopping.Done()
    80  	}()
    81  
    82  	device.log.Verbosef("Routine: receive incoming %s - started", recvName)
    83  
    84  	// receive datagrams until conn is closed
    85  
    86  	var (
    87  		bufsArrs    = make([]*[MaxMessageSize]byte, maxBatchSize)
    88  		bufs        = make([][]byte, maxBatchSize)
    89  		err         error
    90  		sizes       = make([]int, maxBatchSize)
    91  		count       int
    92  		endpoints   = make([]conn.Endpoint, maxBatchSize)
    93  		deathSpiral int
    94  		elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize)
    95  	)
    96  
    97  	for i := range bufsArrs {
    98  		bufsArrs[i] = device.GetMessageBuffer()
    99  		bufs[i] = bufsArrs[i][:]
   100  	}
   101  
   102  	defer func() {
   103  		for i := 0; i < maxBatchSize; i++ {
   104  			if bufsArrs[i] != nil {
   105  				device.PutMessageBuffer(bufsArrs[i])
   106  			}
   107  		}
   108  	}()
   109  
   110  	for {
   111  		count, err = recv(bufs, sizes, endpoints)
   112  		if err != nil {
   113  			if errors.Is(err, net.ErrClosed) {
   114  				return
   115  			}
   116  			device.log.Verbosef("Failed to receive %s packet: %v", recvName, err)
   117  			if neterr, ok := err.(net.Error); ok && !neterr.Temporary() {
   118  				return
   119  			}
   120  			if deathSpiral < 10 {
   121  				deathSpiral++
   122  				time.Sleep(time.Second / 3)
   123  				continue
   124  			}
   125  			return
   126  		}
   127  		deathSpiral = 0
   128  
   129  		// handle each packet in the batch
   130  		for i, size := range sizes[:count] {
   131  			if size < MinMessageSize {
   132  				continue
   133  			}
   134  
   135  			// check size of packet
   136  
   137  			packet := bufsArrs[i][:size]
   138  			msgType := binary.LittleEndian.Uint32(packet[:4])
   139  
   140  			switch msgType {
   141  
   142  			// check if transport
   143  
   144  			case MessageTransportType:
   145  
   146  				// check size
   147  
   148  				if len(packet) < MessageTransportSize {
   149  					continue
   150  				}
   151  
   152  				// lookup key pair
   153  
   154  				receiver := binary.LittleEndian.Uint32(
   155  					packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
   156  				)
   157  				value := device.indexTable.Lookup(receiver)
   158  				keypair := value.keypair
   159  				if keypair == nil {
   160  					continue
   161  				}
   162  
   163  				// check keypair expiry
   164  
   165  				if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
   166  					continue
   167  				}
   168  
   169  				// create work element
   170  				peer := value.peer
   171  				elem := device.GetInboundElement()
   172  				elem.packet = packet
   173  				elem.buffer = bufsArrs[i]
   174  				elem.keypair = keypair
   175  				elem.endpoint = endpoints[i]
   176  				elem.counter = 0
   177  
   178  				elemsForPeer, ok := elemsByPeer[peer]
   179  				if !ok {
   180  					elemsForPeer = device.GetInboundElementsContainer()
   181  					elemsForPeer.Lock()
   182  					elemsByPeer[peer] = elemsForPeer
   183  				}
   184  				elemsForPeer.elems = append(elemsForPeer.elems, elem)
   185  				bufsArrs[i] = device.GetMessageBuffer()
   186  				bufs[i] = bufsArrs[i][:]
   187  				continue
   188  
   189  			// otherwise it is a fixed size & handshake related packet
   190  
   191  			case MessageInitiationType:
   192  				if len(packet) != MessageInitiationSize {
   193  					continue
   194  				}
   195  
   196  			case MessageResponseType:
   197  				if len(packet) != MessageResponseSize {
   198  					continue
   199  				}
   200  
   201  			case MessageCookieReplyType:
   202  				if len(packet) != MessageCookieReplySize {
   203  					continue
   204  				}
   205  
   206  			default:
   207  				device.log.Verbosef("Received message with unknown type")
   208  				continue
   209  			}
   210  
   211  			select {
   212  			case device.queue.handshake.c <- QueueHandshakeElement{
   213  				msgType:  msgType,
   214  				buffer:   bufsArrs[i],
   215  				packet:   packet,
   216  				endpoint: endpoints[i],
   217  			}:
   218  				bufsArrs[i] = device.GetMessageBuffer()
   219  				bufs[i] = bufsArrs[i][:]
   220  			default:
   221  			}
   222  		}
   223  		for peer, elemsContainer := range elemsByPeer {
   224  			if peer.isRunning.Load() {
   225  				peer.queue.inbound.c <- elemsContainer
   226  				device.queue.decryption.c <- elemsContainer
   227  			} else {
   228  				for _, elem := range elemsContainer.elems {
   229  					device.PutMessageBuffer(elem.buffer)
   230  					device.PutInboundElement(elem)
   231  				}
   232  				device.PutInboundElementsContainer(elemsContainer)
   233  			}
   234  			delete(elemsByPeer, peer)
   235  		}
   236  	}
   237  }
   238  
   239  func (device *Device) RoutineDecryption(id int) {
   240  	var nonce [chacha20poly1305.NonceSize]byte
   241  
   242  	defer device.log.Verbosef("Routine: decryption worker %d - stopped", id)
   243  	device.log.Verbosef("Routine: decryption worker %d - started", id)
   244  
   245  	for elemsContainer := range device.queue.decryption.c {
   246  		for _, elem := range elemsContainer.elems {
   247  			// split message into fields
   248  			counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
   249  			content := elem.packet[MessageTransportOffsetContent:]
   250  
   251  			// decrypt and release to consumer
   252  			var err error
   253  			elem.counter = binary.LittleEndian.Uint64(counter)
   254  			// copy counter to nonce
   255  			binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
   256  			elem.packet, err = elem.keypair.receive.Open(
   257  				content[:0],
   258  				nonce[:],
   259  				content,
   260  				nil,
   261  			)
   262  			if err != nil {
   263  				elem.packet = nil
   264  			}
   265  		}
   266  		elemsContainer.Unlock()
   267  	}
   268  }
   269  
   270  /* Handles incoming packets related to handshake
   271   */
   272  func (device *Device) RoutineHandshake(id int) {
   273  	defer func() {
   274  		device.log.Verbosef("Routine: handshake worker %d - stopped", id)
   275  		device.queue.encryption.wg.Done()
   276  	}()
   277  	device.log.Verbosef("Routine: handshake worker %d - started", id)
   278  
   279  	for elem := range device.queue.handshake.c {
   280  
   281  		// handle cookie fields and ratelimiting
   282  
   283  		switch elem.msgType {
   284  
   285  		case MessageCookieReplyType:
   286  
   287  			// unmarshal packet
   288  
   289  			var reply MessageCookieReply
   290  			reader := bytes.NewReader(elem.packet)
   291  			err := binary.Read(reader, binary.LittleEndian, &reply)
   292  			if err != nil {
   293  				device.log.Verbosef("Failed to decode cookie reply")
   294  				goto skip
   295  			}
   296  
   297  			// lookup peer from index
   298  
   299  			entry := device.indexTable.Lookup(reply.Receiver)
   300  
   301  			if entry.peer == nil {
   302  				goto skip
   303  			}
   304  
   305  			// consume reply
   306  
   307  			if peer := entry.peer; peer.isRunning.Load() {
   308  				device.log.Verbosef("Receiving cookie response from %s", elem.endpoint.DstToString())
   309  				if !peer.cookieGenerator.ConsumeReply(&reply) {
   310  					device.log.Verbosef("Could not decrypt invalid cookie response")
   311  				}
   312  			}
   313  
   314  			goto skip
   315  
   316  		case MessageInitiationType, MessageResponseType:
   317  
   318  			// check mac fields and maybe ratelimit
   319  
   320  			if !device.cookieChecker.CheckMAC1(elem.packet) {
   321  				device.log.Verbosef("Received packet with invalid mac1")
   322  				goto skip
   323  			}
   324  
   325  			// endpoints destination address is the source of the datagram
   326  
   327  			if device.IsUnderLoad() {
   328  
   329  				// verify MAC2 field
   330  
   331  				if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) {
   332  					device.SendHandshakeCookie(&elem)
   333  					goto skip
   334  				}
   335  
   336  				// check ratelimiter
   337  
   338  				if !device.rate.limiter.Allow(elem.endpoint.DstIP()) {
   339  					goto skip
   340  				}
   341  			}
   342  
   343  		default:
   344  			device.log.Errorf("Invalid packet ended up in the handshake queue")
   345  			goto skip
   346  		}
   347  
   348  		// handle handshake initiation/response content
   349  
   350  		switch elem.msgType {
   351  		case MessageInitiationType:
   352  
   353  			// unmarshal
   354  
   355  			var msg MessageInitiation
   356  			reader := bytes.NewReader(elem.packet)
   357  			err := binary.Read(reader, binary.LittleEndian, &msg)
   358  			if err != nil {
   359  				device.log.Errorf("Failed to decode initiation message")
   360  				goto skip
   361  			}
   362  
   363  			// consume initiation
   364  
   365  			peer := device.ConsumeMessageInitiation(&msg)
   366  			if peer == nil {
   367  				device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString())
   368  				goto skip
   369  			}
   370  
   371  			// update timers
   372  
   373  			peer.timersAnyAuthenticatedPacketTraversal()
   374  			peer.timersAnyAuthenticatedPacketReceived()
   375  
   376  			// update endpoint
   377  			peer.SetEndpointFromPacket(elem.endpoint)
   378  
   379  			device.log.Verbosef("%v - Received handshake initiation", peer)
   380  			peer.rxBytes.Add(uint64(len(elem.packet)))
   381  
   382  			peer.SendHandshakeResponse()
   383  
   384  		case MessageResponseType:
   385  
   386  			// unmarshal
   387  
   388  			var msg MessageResponse
   389  			reader := bytes.NewReader(elem.packet)
   390  			err := binary.Read(reader, binary.LittleEndian, &msg)
   391  			if err != nil {
   392  				device.log.Errorf("Failed to decode response message")
   393  				goto skip
   394  			}
   395  
   396  			// consume response
   397  
   398  			peer := device.ConsumeMessageResponse(&msg)
   399  			if peer == nil {
   400  				device.log.Verbosef("Received invalid response message from %s", elem.endpoint.DstToString())
   401  				goto skip
   402  			}
   403  
   404  			// update endpoint
   405  			peer.SetEndpointFromPacket(elem.endpoint)
   406  
   407  			device.log.Verbosef("%v - Received handshake response", peer)
   408  			peer.rxBytes.Add(uint64(len(elem.packet)))
   409  
   410  			// update timers
   411  
   412  			peer.timersAnyAuthenticatedPacketTraversal()
   413  			peer.timersAnyAuthenticatedPacketReceived()
   414  
   415  			// derive keypair
   416  
   417  			err = peer.BeginSymmetricSession()
   418  
   419  			if err != nil {
   420  				device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
   421  				goto skip
   422  			}
   423  
   424  			peer.timersSessionDerived()
   425  			peer.timersHandshakeComplete()
   426  			peer.SendKeepalive()
   427  		}
   428  	skip:
   429  		device.PutMessageBuffer(elem.buffer)
   430  	}
   431  }
   432  
   433  func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
   434  	device := peer.device
   435  	defer func() {
   436  		device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer)
   437  		peer.stopping.Done()
   438  	}()
   439  	device.log.Verbosef("%v - Routine: sequential receiver - started", peer)
   440  
   441  	bufs := make([][]byte, 0, maxBatchSize)
   442  
   443  	for elemsContainer := range peer.queue.inbound.c {
   444  		if elemsContainer == nil {
   445  			return
   446  		}
   447  		elemsContainer.Lock()
   448  		validTailPacket := -1
   449  		dataPacketReceived := false
   450  		rxBytesLen := uint64(0)
   451  		for i, elem := range elemsContainer.elems {
   452  			if elem.packet == nil {
   453  				// decryption failed
   454  				continue
   455  			}
   456  
   457  			if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
   458  				continue
   459  			}
   460  
   461  			validTailPacket = i
   462  			if peer.ReceivedWithKeypair(elem.keypair) {
   463  				peer.SetEndpointFromPacket(elem.endpoint)
   464  				peer.timersHandshakeComplete()
   465  				peer.SendStagedPackets()
   466  			}
   467  			rxBytesLen += uint64(len(elem.packet) + MinMessageSize)
   468  
   469  			if len(elem.packet) == 0 {
   470  				device.log.Verbosef("%v - Receiving keepalive packet", peer)
   471  				continue
   472  			}
   473  			dataPacketReceived = true
   474  
   475  			switch elem.packet[0] >> 4 {
   476  			case 4:
   477  				if len(elem.packet) < ipv4.HeaderLen {
   478  					continue
   479  				}
   480  				field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
   481  				length := binary.BigEndian.Uint16(field)
   482  				if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
   483  					continue
   484  				}
   485  				elem.packet = elem.packet[:length]
   486  				src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
   487  				if device.allowedips.Lookup(src) != peer {
   488  					device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
   489  					continue
   490  				}
   491  
   492  			case 6:
   493  				if len(elem.packet) < ipv6.HeaderLen {
   494  					continue
   495  				}
   496  				field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
   497  				length := binary.BigEndian.Uint16(field)
   498  				length += ipv6.HeaderLen
   499  				if int(length) > len(elem.packet) {
   500  					continue
   501  				}
   502  				elem.packet = elem.packet[:length]
   503  				src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
   504  				if device.allowedips.Lookup(src) != peer {
   505  					device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
   506  					continue
   507  				}
   508  
   509  			default:
   510  				device.log.Verbosef("Packet with invalid IP version from %v", peer)
   511  				continue
   512  			}
   513  
   514  			bufs = append(bufs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)])
   515  		}
   516  
   517  		peer.rxBytes.Add(rxBytesLen)
   518  		if validTailPacket >= 0 {
   519  			peer.SetEndpointFromPacket(elemsContainer.elems[validTailPacket].endpoint)
   520  			peer.keepKeyFreshReceiving()
   521  			peer.timersAnyAuthenticatedPacketTraversal()
   522  			peer.timersAnyAuthenticatedPacketReceived()
   523  		}
   524  		if dataPacketReceived {
   525  			peer.timersDataReceived()
   526  		}
   527  		if len(bufs) > 0 {
   528  			_, err := device.tun.device.Write(bufs, MessageTransportOffsetContent)
   529  			if err != nil && !device.isClosed() {
   530  				device.log.Errorf("Failed to write packets to TUN device: %v", err)
   531  			}
   532  		}
   533  		for _, elem := range elemsContainer.elems {
   534  			device.PutMessageBuffer(elem.buffer)
   535  			device.PutInboundElement(elem)
   536  		}
   537  		bufs = bufs[:0]
   538  		device.PutInboundElementsContainer(elemsContainer)
   539  	}
   540  }