github.com/cawidtu/notwireguard-go/device@v0.0.0-20230523131112-68e8e5ce9cdf/receive.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"bytes"
    10  	"encoding/binary"
    11  	"errors"
    12  	"net"
    13  	"sync"
    14  	"sync/atomic"
    15  	"time"
    16          "github.com/aead/chacha20"
    17  	"golang.org/x/crypto/chacha20poly1305"
    18  	"golang.org/x/net/ipv4"
    19  	"golang.org/x/net/ipv6"
    20  	"github.com/cawidtu/notwireguard-go/conn"
    21  )
    22  
// QueueHandshakeElement is one handshake-related message (initiation,
// response or cookie reply) passed from RoutineReceiveIncoming to a
// RoutineHandshake worker over device.queue.handshake.
type QueueHandshakeElement struct {
	// msgType is the little-endian uint32 read from the first four bytes of
	// the deobfuscated packet.
	msgType  uint32
	// packet is the message itself: a slice into buffer, cut down to the
	// fixed size of its message type when the datagram was longer.
	packet   []byte
	// endpoint is the endpoint value the receive function returned for the
	// datagram.
	endpoint conn.Endpoint
	// buffer is the message buffer backing packet; RoutineHandshake hands it
	// back with PutMessageBuffer once the element has been processed.
	buffer   *[MaxMessageSize]byte
}
    29  
// QueueInboundElement is one transport (data) message travelling through the
// receive pipeline. RoutineReceiveIncoming places the same element on both the
// peer's inbound queue and the device's decryption queue.
//
// The embedded mutex is used as a completion signal: the element is locked
// before it is queued, RoutineDecryption unlocks it once decryption has been
// attempted, and RoutineSequentialReceiver locks it again before looking at
// the result.
type QueueInboundElement struct {
	sync.Mutex
	// buffer is the message buffer backing packet.
	buffer   *[MaxMessageSize]byte
	// packet is the received message before decryption; afterwards it is the
	// decrypted content, or nil when decryption failed.
	packet   []byte
	// counter is the message counter read from the transport header by
	// RoutineDecryption; it is also used to build the AEAD nonce.
	counter  uint64
	// keypair is the keypair found through the receiver index of the message.
	keypair  *Keypair
	// endpoint is the endpoint value the receive function returned for the
	// datagram.
	endpoint conn.Endpoint
}
    38  
    39  // clearPointers clears elem fields that contain pointers.
    40  // This makes the garbage collector's life easier and
    41  // avoids accidentally keeping other objects around unnecessarily.
    42  // It also reduces the possible collateral damage from use-after-free bugs.
    43  func (elem *QueueInboundElement) clearPointers() {
    44  	elem.buffer = nil
    45  	elem.packet = nil
    46  	elem.keypair = nil
    47  	elem.endpoint = nil
    48  }
    49  
    50  func wgDeobfuscatePacket(obfuscator [NoisePublicKeySize]byte, buf []byte, len int) {
    51  
    52      decryptLength := min(len&0xFFFC, NoiseObfuscateLenMax) - 4 // sizeof(u32)
    53  
    54      // Initialize the cipher with the obfuscator as the key and the nonce derived from the packet contents.
    55  
    56      var nonce [8]byte;
    57  
    58      copy(nonce[0:4], buf[decryptLength : decryptLength + 4])
    59  
    60      state, err :=  chacha20.NewCipher(nonce[:],obfuscator[:])
    61      if err != nil {
    62          panic(err)
    63      }
    64  
    65      state.XORKeyStream(buf[:decryptLength], buf[:decryptLength])
    66  }
    67  
    68  
    69  /* Called when a new authenticated message has been received
    70   *
    71   * NOTE: Not thread safe, but called by sequential receiver!
    72   */
    73  func (peer *Peer) keepKeyFreshReceiving() {
    74  	if peer.timers.sentLastMinuteHandshake.Get() {
    75  		return
    76  	}
    77  	keypair := peer.keypairs.Current()
    78  	if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
    79  		peer.timers.sentLastMinuteHandshake.Set(true)
    80  		peer.SendHandshakeInitiation(false)
    81  	}
    82  }
    83  
/* Receives incoming datagrams for the device
 *
 * Every time the bind is updated a new routine is started for
 * IPv4 and IPv6 (separately)
 *
 * Each datagram is read with recv into a message buffer, deobfuscated in
 * place and dispatched on its message type:
 *
 *   - transport messages are wrapped in a QueueInboundElement and sent to
 *     both the owning peer's inbound queue and the device decryption queue;
 *   - initiation, response and cookie reply messages are offered to the
 *     handshake queue without blocking and dropped when it is full.
 *
 * A buffer that has been handed to a queue is replaced by a fresh one;
 * otherwise the same buffer is reused for the next read.
 *
 * The routine returns when recv reports net.ErrClosed, a net.Error that is
 * not temporary, or an eleventh consecutive failure.
 */
func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
	recvName := recv.PrettyName()
	defer func() {
		device.log.Verbosef("Routine: receive incoming %s - stopped", recvName)
		device.queue.decryption.wg.Done()
		device.queue.handshake.wg.Done()
		device.net.stopping.Done()
	}()

	device.log.Verbosef("Routine: receive incoming %s - started", recvName)

	// receive datagrams until conn is closed

	buffer := device.GetMessageBuffer()

	var (
		err         error
		size        int
		endpoint    conn.Endpoint
		deathSpiral int // number of consecutive failed reads
	)

	for {
		size, endpoint, err = recv(buffer[:])

		if err != nil {
			// the buffer is handed back on every error path; a new one is
			// fetched below if the read is going to be retried
			device.PutMessageBuffer(buffer)
			if errors.Is(err, net.ErrClosed) {
				return
			}
			device.log.Verbosef("Failed to receive %s packet: %v", recvName, err)
			if neterr, ok := err.(net.Error); ok && !neterr.Temporary() {
				return
			}
			// retry after a short pause, at most ten times in a row
			if deathSpiral < 10 {
				deathSpiral++
				time.Sleep(time.Second / 3)
				buffer = device.GetMessageBuffer()
				continue
			}
			return
		}
		deathSpiral = 0

		// too short to be any valid message; reuse the buffer

		if size < MinMessageSize {
			continue
		}

		// deobfuscate packet

		wgDeobfuscatePacket(device.staticIdentity.obfuscator, buffer[:], size)
		packet := buffer[:size]
		msgType := binary.LittleEndian.Uint32(packet[:4])

		var okay bool     // handshake message has an acceptable length
		var trim bool     // handshake message is longer than its fixed size
		var headerLen int // fixed size of the handshake message type

		switch msgType {

		// check if transport

		case MessageTransportType:

			// check size

			if len(packet) < MessageTransportSize {
				continue
			}

			// lookup key pair

			receiver := binary.LittleEndian.Uint32(
				packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
			)
			value := device.indexTable.Lookup(receiver)
			keypair := value.keypair
			if keypair == nil {
				continue
			}

			// check keypair expiry

			if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
				continue
			}

			// create work element
			peer := value.peer
			elem := device.GetInboundElement()
			elem.packet = packet
			elem.buffer = buffer
			elem.keypair = keypair
			elem.endpoint = endpoint
			elem.counter = 0
			elem.Mutex = sync.Mutex{}
			// stays locked until RoutineDecryption has processed the element
			elem.Lock()

			// add to decryption queues
			if peer.isRunning.Get() {
				peer.queue.inbound.c <- elem
				device.queue.decryption.c <- elem
				// the queued element now refers to the old buffer
				buffer = device.GetMessageBuffer()
			} else {
				// element goes back to the pool; the buffer is reused for
				// the next read
				device.PutInboundElement(elem)
			}
			continue

		// otherwise it is a "fixed size" & handshake related packet

		// check minimum length instead of fixed length for obfuscation

		case MessageInitiationType:
			okay = len(packet) >= MessageInitiationSize && len(packet) <= NoiseObfuscateLenMax
			trim = len(packet) > MessageInitiationSize
			headerLen = MessageInitiationSize

		case MessageResponseType:
			okay = len(packet) >= MessageResponseSize && len(packet) <= NoiseObfuscateLenMax
                        trim = len(packet) > MessageResponseSize
			headerLen = MessageResponseSize

		case MessageCookieReplyType:
			okay = len(packet) >= MessageCookieReplySize && len(packet) <= NoiseObfuscateLenMax
                        trim = len(packet) > MessageCookieReplySize
                        headerLen = MessageCookieReplySize

		default:
			device.log.Verbosef("Received message with unknown type")
		}

		// remove the junk added in obfuscation, if necessary
		if trim {
			packet = packet[:headerLen]
		}



		if okay {
			// non-blocking send: the message is dropped, and the buffer
			// reused, when the handshake queue is full
			select {
			case device.queue.handshake.c <- QueueHandshakeElement{
				msgType:  msgType,
				buffer:   buffer,
				packet:   packet,
				endpoint: endpoint,
			}:
				buffer = device.GetMessageBuffer()
			default:
			}
		}
	}
}
   241  
   242  func (device *Device) RoutineDecryption(id int) {
   243  	var nonce [chacha20poly1305.NonceSize]byte
   244  
   245  	defer device.log.Verbosef("Routine: decryption worker %d - stopped", id)
   246  	device.log.Verbosef("Routine: decryption worker %d - started", id)
   247  
   248  	for elem := range device.queue.decryption.c {
   249  		// split message into fields
   250  		counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
   251  		content := elem.packet[MessageTransportOffsetContent:]
   252  
   253  		// decrypt and release to consumer
   254  		var err error
   255  		elem.counter = binary.LittleEndian.Uint64(counter)
   256  		// copy counter to nonce
   257  		binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
   258  		elem.packet, err = elem.keypair.receive.Open(
   259  			content[:0],
   260  			nonce[:],
   261  			content,
   262  			nil,
   263  		)
   264  		if err != nil {
   265  			elem.packet = nil
   266  		}
   267  		elem.Unlock()
   268  	}
   269  }
   270  
/* Handles incoming packets related to handshake
 *
 * Worker identified by id that drains device.queue.handshake until the
 * channel is closed. For every element:
 *
 *   - a cookie reply is decoded and, if the peer found through the receiver
 *     index is running, passed to that peer's cookie generator;
 *   - an initiation or response first has MAC1 checked; while the device
 *     reports being under load, MAC2 is checked as well (a cookie is sent
 *     back when it fails) and the rate limiter is consulted; the message is
 *     then decoded and consumed.
 *
 * Whatever the outcome, the element's buffer is returned with
 * PutMessageBuffer at the skip label.
 */
func (device *Device) RoutineHandshake(id int) {
	defer func() {
		device.log.Verbosef("Routine: handshake worker %d - stopped", id)
		device.queue.encryption.wg.Done()
	}()
	device.log.Verbosef("Routine: handshake worker %d - started", id)

	for elem := range device.queue.handshake.c {

		// handle cookie fields and ratelimiting

		switch elem.msgType {

		case MessageCookieReplyType:

			// unmarshal packet

			var reply MessageCookieReply
			reader := bytes.NewReader(elem.packet)
			err := binary.Read(reader, binary.LittleEndian, &reply)
			if err != nil {
				device.log.Verbosef("Failed to decode cookie reply")
				goto skip
			}

			// lookup peer from index

			entry := device.indexTable.Lookup(reply.Receiver)

			if entry.peer == nil {
				goto skip
			}

			// consume reply

			if peer := entry.peer; peer.isRunning.Get() {
				device.log.Verbosef("Receiving cookie response from %s", elem.endpoint.DstToString())
				if !peer.cookieGenerator.ConsumeReply(&reply) {
					device.log.Verbosef("Could not decrypt invalid cookie response")
				}
			}

			// a cookie reply needs no further processing

			goto skip

		case MessageInitiationType, MessageResponseType:

			// check mac fields and maybe ratelimit

			if !device.cookieChecker.CheckMAC1(elem.packet) {
				device.log.Verbosef("Received packet with invalid mac1")
				goto skip
			}

			// endpoints destination address is the source of the datagram

			if device.IsUnderLoad() {

				// verify MAC2 field

				if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) {
                                        // extract the obfuscator key from the packet received:
                                        // bytes 8..40 of the message.
                                        // NOTE(review): the same byte range is used for both
                                        // initiation and response messages - confirm against
                                        // the message layouts that this is intended.
                                        var obfuscator [32]byte
                                        copy(obfuscator[:], elem.packet[8:40])
					device.SendHandshakeCookie(&elem, obfuscator) 
					goto skip
				}

				// check ratelimiter

				if !device.rate.limiter.Allow(elem.endpoint.DstIP()) {
					goto skip
				}
			}

		default:
			device.log.Errorf("Invalid packet ended up in the handshake queue")
			goto skip
		}

		// handle handshake initiation/response content

		switch elem.msgType {
		case MessageInitiationType:

			// unmarshal

			var msg MessageInitiation
			reader := bytes.NewReader(elem.packet)
			err := binary.Read(reader, binary.LittleEndian, &msg)
			if err != nil {
				device.log.Errorf("Failed to decode initiation message")
				goto skip
			}

			// consume initiation

			peer := device.ConsumeMessageInitiation(&msg)
			if peer == nil {
				device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString())
				goto skip
			}

			// update timers

			peer.timersAnyAuthenticatedPacketTraversal()
			peer.timersAnyAuthenticatedPacketReceived()

			// update endpoint
			peer.SetEndpointFromPacket(elem.endpoint)

			device.log.Verbosef("%v - Received handshake initiation", peer)
			atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))

			peer.SendHandshakeResponse()

		case MessageResponseType:

			// unmarshal

			var msg MessageResponse
			reader := bytes.NewReader(elem.packet)
			err := binary.Read(reader, binary.LittleEndian, &msg)
			if err != nil {
				device.log.Errorf("Failed to decode response message")
				goto skip
			}

			// consume response

			peer := device.ConsumeMessageResponse(&msg)
			if peer == nil {
				device.log.Verbosef("Received invalid response message from %s", elem.endpoint.DstToString())
				goto skip
			}

			// update endpoint
			peer.SetEndpointFromPacket(elem.endpoint)

			device.log.Verbosef("%v - Received handshake response", peer)
			atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))

			// update timers

			peer.timersAnyAuthenticatedPacketTraversal()
			peer.timersAnyAuthenticatedPacketReceived()

			// derive keypair

			err = peer.BeginSymmetricSession()

			if err != nil {
				device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
				goto skip
			}

			peer.timersSessionDerived()
			peer.timersHandshakeComplete()
			peer.SendKeepalive()
		}
	skip:
		// every path ends here so the message buffer is always handed back
		device.PutMessageBuffer(elem.buffer)
	}
}
   436  
   437  func (peer *Peer) RoutineSequentialReceiver() {
   438  	device := peer.device
   439  	defer func() {
   440  		device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer)
   441  		peer.stopping.Done()
   442  	}()
   443  	device.log.Verbosef("%v - Routine: sequential receiver - started", peer)
   444  
   445  	for elem := range peer.queue.inbound.c {
   446  		if elem == nil {
   447  			return
   448  		}
   449  		var err error
   450  		elem.Lock()
   451  		if elem.packet == nil {
   452  			// decryption failed
   453  			goto skip
   454  		}
   455  
   456  		if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
   457  			goto skip
   458  		}
   459  
   460  		peer.SetEndpointFromPacket(elem.endpoint)
   461  		if peer.ReceivedWithKeypair(elem.keypair) {
   462  			peer.timersHandshakeComplete()
   463  			peer.SendStagedPackets()
   464  		}
   465  
   466  		peer.keepKeyFreshReceiving()
   467  		peer.timersAnyAuthenticatedPacketTraversal()
   468  		peer.timersAnyAuthenticatedPacketReceived()
   469  		atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)+MinMessageSize))
   470  
   471  		if len(elem.packet) == 0 {
   472  			device.log.Verbosef("%v - Receiving keepalive packet", peer)
   473  			goto skip
   474  		}
   475  		peer.timersDataReceived()
   476  
   477  		switch elem.packet[0] >> 4 {
   478  		case ipv4.Version:
   479  			if len(elem.packet) < ipv4.HeaderLen {
   480  				goto skip
   481  			}
   482  			field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
   483  			length := binary.BigEndian.Uint16(field)
   484  			if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
   485  				goto skip
   486  			}
   487  			elem.packet = elem.packet[:length]
   488  			src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
   489  			if device.allowedips.Lookup(src) != peer {
   490  				device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
   491  				goto skip
   492  			}
   493  
   494  		case ipv6.Version:
   495  			if len(elem.packet) < ipv6.HeaderLen {
   496  				goto skip
   497  			}
   498  			field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
   499  			length := binary.BigEndian.Uint16(field)
   500  			length += ipv6.HeaderLen
   501  			if int(length) > len(elem.packet) {
   502  				goto skip
   503  			}
   504  			elem.packet = elem.packet[:length]
   505  			src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
   506  			if device.allowedips.Lookup(src) != peer {
   507  				device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
   508  				goto skip
   509  			}
   510  
   511  		default:
   512  			device.log.Verbosef("Packet with invalid IP version from %v", peer)
   513  			goto skip
   514  		}
   515  
   516  		_, err = device.tun.device.Write(elem.buffer[:MessageTransportOffsetContent+len(elem.packet)], MessageTransportOffsetContent)
   517  		if err != nil && !device.isClosed() {
   518  			device.log.Errorf("Failed to write packet to TUN device: %v", err)
   519  		}
   520  		if len(peer.queue.inbound.c) == 0 {
   521  			err = device.tun.device.Flush()
   522  			if err != nil {
   523  				peer.device.log.Errorf("Unable to flush packets: %v", err)
   524  			}
   525  		}
   526  	skip:
   527  		device.PutMessageBuffer(elem.buffer)
   528  		device.PutInboundElement(elem)
   529  	}
   530  }