github.com/amnezia-vpn/amneziawg-go@v0.2.8/device/noise-protocol.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"errors"
    10  	"fmt"
    11  	"sync"
    12  	"time"
    13  
    14  	"golang.org/x/crypto/blake2s"
    15  	"golang.org/x/crypto/chacha20poly1305"
    16  	"golang.org/x/crypto/poly1305"
    17  
    18  	"github.com/amnezia-vpn/amneziawg-go/tai64n"
    19  )
    20  
    21  type handshakeState int
    22  
    23  const (
    24  	handshakeZeroed = handshakeState(iota)
    25  	handshakeInitiationCreated
    26  	handshakeInitiationConsumed
    27  	handshakeResponseCreated
    28  	handshakeResponseConsumed
    29  )
    30  
    31  func (hs handshakeState) String() string {
    32  	switch hs {
    33  	case handshakeZeroed:
    34  		return "handshakeZeroed"
    35  	case handshakeInitiationCreated:
    36  		return "handshakeInitiationCreated"
    37  	case handshakeInitiationConsumed:
    38  		return "handshakeInitiationConsumed"
    39  	case handshakeResponseCreated:
    40  		return "handshakeResponseCreated"
    41  	case handshakeResponseConsumed:
    42  		return "handshakeResponseConsumed"
    43  	default:
    44  		return fmt.Sprintf("Handshake(UNKNOWN:%d)", int(hs))
    45  	}
    46  }
    47  
// Protocol identification strings.
//
// NoiseConstruction is hashed in init to produce InitialChainKey, and
// WGIdentifier is then mixed in to produce InitialHash.
// WGLabelMAC1 and WGLabelCookie are not referenced in this file.
// NOTE(review): presumably they are the key-derivation labels for MAC1 and
// cookie handling — confirm against their users.
const (
	NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"
	WGIdentifier      = "WireGuard v1 zx2c4 Jason@zx2c4.com"
	WGLabelMAC1       = "mac1----"
	WGLabelCookie     = "cookie--"
)
    54  
// Message type identifiers carried in the Type field of each message.
//
// These are declared as variables rather than constants. Within this file
// MessageInitiationType and MessageResponseType are only read while holding
// device.aSecMux for reading.
// NOTE(review): presumably the values can be reassigned at runtime under
// that lock — confirm against the code that writes them.
var (
	MessageInitiationType  uint32 = 1
	MessageResponseType    uint32 = 2
	MessageCookieReplyType uint32 = 3
	MessageTransportType   uint32 = 4
)
    61  
// Sizes, in bytes, of the protocol messages.
//
// MessageTransportSize is the transport header plus one Poly1305 tag, i.e. a
// transport message with no payload; MessageKeepaliveSize and
// MessageHandshakeSize are aliases of other sizes in this group.
const (
	MessageInitiationSize      = 148                                           // size of handshake initiation message
	MessageResponseSize        = 92                                            // size of response message
	MessageCookieReplySize     = 64                                            // size of cookie reply message
	MessageTransportHeaderSize = 16                                            // size of data preceding content in transport message
	MessageTransportSize       = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport
	MessageKeepaliveSize       = MessageTransportSize                          // size of keepalive
	MessageHandshakeSize       = MessageInitiationSize                         // size of largest handshake related message
)
    71  
// Offsets used when accessing a transport message.
// NOTE(review): presumably byte offsets into the marshalled message; the
// values match the cumulative sizes of the Type (uint32), Receiver (uint32)
// and Counter (uint64) fields of MessageTransport — confirm against the
// code that parses transport packets.
const (
	MessageTransportOffsetReceiver = 4
	MessageTransportOffsetCounter  = 8
	MessageTransportOffsetContent  = 16
)
    77  
// packetSizeToMsgType is declared here but neither initialized nor used in
// this file.
// NOTE(review): presumably maps a packet length to a message type value —
// confirm where it is populated.
var packetSizeToMsgType map[int]uint32

// msgTypeToJunkSize is declared here but neither initialized nor used in
// this file.
// NOTE(review): presumably maps a message type value to a junk/padding
// length — confirm where it is populated.
var msgTypeToJunkSize map[uint32]int
    81  
/* Type is an 8-bit field followed by 3 nul bytes. Because the messages are
 * marshalled in little-endian byte order, the type can be treated as a
 * single 32-bit unsigned integer (for now).
 */
    87  
// MessageInitiation is the first handshake message, built by
// CreateMessageInitiation and processed by ConsumeMessageInitiation.
//
//   - Type is set to MessageInitiationType.
//   - Sender is the index allocated for the handshake in the sender's index
//     table.
//   - Ephemeral is the public half of the sender's freshly generated
//     ephemeral key.
//   - Static is the sender's static public key sealed with ChaCha20-Poly1305,
//     hence the extra poly1305.TagSize bytes.
//   - Timestamp is a TAI64N timestamp sealed the same way.
//   - MAC1 and MAC2 are not written or checked in this file.
//     NOTE(review): presumably filled in and verified by the cookie/MAC
//     code — confirm.
type MessageInitiation struct {
	Type      uint32
	Sender    uint32
	Ephemeral NoisePublicKey
	Static    [NoisePublicKeySize + poly1305.TagSize]byte
	Timestamp [tai64n.TimestampSize + poly1305.TagSize]byte
	MAC1      [blake2s.Size128]byte
	MAC2      [blake2s.Size128]byte
}
    97  
// MessageResponse is the second handshake message, built by
// CreateMessageResponse and processed by ConsumeMessageResponse.
//
//   - Type is set to MessageResponseType.
//   - Sender is the responder's newly allocated local index.
//   - Receiver is the index the initiator announced as Sender in its
//     initiation; the initiator uses it to look the handshake up again.
//   - Ephemeral is the public half of the responder's ephemeral key.
//   - Empty is the AEAD seal of an empty plaintext, i.e. only the
//     authentication tag over the handshake hash.
//   - MAC1 and MAC2 are not written or checked in this file.
//     NOTE(review): presumably handled by the cookie/MAC code — confirm.
type MessageResponse struct {
	Type      uint32
	Sender    uint32
	Receiver  uint32
	Ephemeral NoisePublicKey
	Empty     [poly1305.TagSize]byte
	MAC1      [blake2s.Size128]byte
	MAC2      [blake2s.Size128]byte
}
   107  
// MessageTransport describes the fields of a transport (data) message: a
// type, a receiver index, a 64-bit counter and the variable-length content.
// The type is not referenced by any function in this file.
// NOTE(review): presumably Content is the encrypted payload and Counter the
// AEAD nonce counter — confirm against the send/receive paths.
type MessageTransport struct {
	Type     uint32
	Receiver uint32
	Counter  uint64
	Content  []byte
}
   114  
// MessageCookieReply describes the fields of a cookie reply message. Nonce
// has the extended (XChaCha20-Poly1305) nonce size and Cookie has room for a
// 128-bit value plus one Poly1305 tag. The type is not referenced by any
// function in this file.
// NOTE(review): presumably Cookie is the sealed cookie and Receiver the
// index of the handshake being answered — confirm against the cookie code.
type MessageCookieReply struct {
	Type     uint32
	Receiver uint32
	Nonce    [chacha20poly1305.NonceSizeX]byte
	Cookie   [blake2s.Size128 + poly1305.TagSize]byte
}
   121  
// Handshake holds the per-peer state of an in-progress Noise handshake.
//
// The Create*/Consume* functions and BeginSymmetricSession in this file take
// mutex before reading or writing the other fields; Clear does not lock and
// leaves that to its caller.
//
// lastTimestamp is the most recent initiation timestamp accepted from the
// peer and lastInitiationConsumption is when that initiation was accepted;
// ConsumeMessageInitiation uses them to reject replayed and too-frequent
// initiations. lastSentHandshake is not touched in this file.
type Handshake struct {
	state                     handshakeState
	mutex                     sync.RWMutex
	hash                      [blake2s.Size]byte       // hash value
	chainKey                  [blake2s.Size]byte       // chain key
	presharedKey              NoisePresharedKey        // psk
	localEphemeral            NoisePrivateKey          // ephemeral secret key
	localIndex                uint32                   // used to clear hash-table
	remoteIndex               uint32                   // index for sending
	remoteStatic              NoisePublicKey           // long term key
	remoteEphemeral           NoisePublicKey           // ephemeral public key
	precomputedStaticStatic   [NoisePublicKeySize]byte // precomputed shared secret
	lastTimestamp             tai64n.Timestamp
	lastInitiationConsumption time.Time
	lastSentHandshake         time.Time
}
   138  
// Precomputed starting values shared by every handshake.
//
// InitialChainKey and InitialHash are filled in by init: the former is the
// BLAKE2s-256 digest of NoiseConstruction, the latter is that digest hashed
// together with WGIdentifier. ZeroNonce is never assigned in this file and
// therefore stays all-zero; it is the nonce passed to every handshake AEAD
// Seal and Open call in this file.
var (
	InitialChainKey [blake2s.Size]byte
	InitialHash     [blake2s.Size]byte
	ZeroNonce       [chacha20poly1305.NonceSize]byte
)
   144  
   145  func mixKey(dst, c *[blake2s.Size]byte, data []byte) {
   146  	KDF1(dst, c[:], data)
   147  }
   148  
   149  func mixHash(dst, h *[blake2s.Size]byte, data []byte) {
   150  	hash, _ := blake2s.New256(nil)
   151  	hash.Write(h[:])
   152  	hash.Write(data)
   153  	hash.Sum(dst[:0])
   154  	hash.Reset()
   155  }
   156  
   157  func (h *Handshake) Clear() {
   158  	setZero(h.localEphemeral[:])
   159  	setZero(h.remoteEphemeral[:])
   160  	setZero(h.chainKey[:])
   161  	setZero(h.hash[:])
   162  	h.localIndex = 0
   163  	h.state = handshakeZeroed
   164  }
   165  
   166  func (h *Handshake) mixHash(data []byte) {
   167  	mixHash(&h.hash, &h.hash, data)
   168  }
   169  
   170  func (h *Handshake) mixKey(data []byte) {
   171  	mixKey(&h.chainKey, &h.chainKey, data)
   172  }
   173  
   174  /* Do basic precomputations
   175   */
   176  func init() {
   177  	InitialChainKey = blake2s.Sum256([]byte(NoiseConstruction))
   178  	mixHash(&InitialHash, &InitialChainKey, []byte(WGIdentifier))
   179  }
   180  
   181  func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
   182  	device.staticIdentity.RLock()
   183  	defer device.staticIdentity.RUnlock()
   184  
   185  	handshake := &peer.handshake
   186  	handshake.mutex.Lock()
   187  	defer handshake.mutex.Unlock()
   188  
   189  	// create ephemeral key
   190  	var err error
   191  	handshake.hash = InitialHash
   192  	handshake.chainKey = InitialChainKey
   193  	handshake.localEphemeral, err = newPrivateKey()
   194  	if err != nil {
   195  		return nil, err
   196  	}
   197  
   198  	handshake.mixHash(handshake.remoteStatic[:])
   199  
   200  	device.aSecMux.RLock()
   201  	msg := MessageInitiation{
   202  		Type:      MessageInitiationType,
   203  		Ephemeral: handshake.localEphemeral.publicKey(),
   204  	}
   205  	device.aSecMux.RUnlock()
   206  
   207  	handshake.mixKey(msg.Ephemeral[:])
   208  	handshake.mixHash(msg.Ephemeral[:])
   209  
   210  	// encrypt static key
   211  	ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
   212  	if err != nil {
   213  		return nil, err
   214  	}
   215  	var key [chacha20poly1305.KeySize]byte
   216  	KDF2(
   217  		&handshake.chainKey,
   218  		&key,
   219  		handshake.chainKey[:],
   220  		ss[:],
   221  	)
   222  	aead, _ := chacha20poly1305.New(key[:])
   223  	aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:])
   224  	handshake.mixHash(msg.Static[:])
   225  
   226  	// encrypt timestamp
   227  	if isZero(handshake.precomputedStaticStatic[:]) {
   228  		return nil, errInvalidPublicKey
   229  	}
   230  	KDF2(
   231  		&handshake.chainKey,
   232  		&key,
   233  		handshake.chainKey[:],
   234  		handshake.precomputedStaticStatic[:],
   235  	)
   236  	timestamp := tai64n.Now()
   237  	aead, _ = chacha20poly1305.New(key[:])
   238  	aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
   239  
   240  	// assign index
   241  	device.indexTable.Delete(handshake.localIndex)
   242  	msg.Sender, err = device.indexTable.NewIndexForHandshake(peer, handshake)
   243  	if err != nil {
   244  		return nil, err
   245  	}
   246  	handshake.localIndex = msg.Sender
   247  
   248  	handshake.mixHash(msg.Timestamp[:])
   249  	handshake.state = handshakeInitiationCreated
   250  	return &msg, nil
   251  }
   252  
   253  func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
   254  	var (
   255  		hash     [blake2s.Size]byte
   256  		chainKey [blake2s.Size]byte
   257  	)
   258  
   259  	device.aSecMux.RLock()
   260  	if msg.Type != MessageInitiationType {
   261  		device.aSecMux.RUnlock()
   262  		return nil
   263  	}
   264  	device.aSecMux.RUnlock()
   265  
   266  	device.staticIdentity.RLock()
   267  	defer device.staticIdentity.RUnlock()
   268  
   269  	mixHash(&hash, &InitialHash, device.staticIdentity.publicKey[:])
   270  	mixHash(&hash, &hash, msg.Ephemeral[:])
   271  	mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
   272  
   273  	// decrypt static key
   274  	var peerPK NoisePublicKey
   275  	var key [chacha20poly1305.KeySize]byte
   276  	ss, err := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
   277  	if err != nil {
   278  		return nil
   279  	}
   280  	KDF2(&chainKey, &key, chainKey[:], ss[:])
   281  	aead, _ := chacha20poly1305.New(key[:])
   282  	_, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
   283  	if err != nil {
   284  		return nil
   285  	}
   286  	mixHash(&hash, &hash, msg.Static[:])
   287  
   288  	// lookup peer
   289  
   290  	peer := device.LookupPeer(peerPK)
   291  	if peer == nil || !peer.isRunning.Load() {
   292  		return nil
   293  	}
   294  
   295  	handshake := &peer.handshake
   296  
   297  	// verify identity
   298  
   299  	var timestamp tai64n.Timestamp
   300  
   301  	handshake.mutex.RLock()
   302  
   303  	if isZero(handshake.precomputedStaticStatic[:]) {
   304  		handshake.mutex.RUnlock()
   305  		return nil
   306  	}
   307  	KDF2(
   308  		&chainKey,
   309  		&key,
   310  		chainKey[:],
   311  		handshake.precomputedStaticStatic[:],
   312  	)
   313  	aead, _ = chacha20poly1305.New(key[:])
   314  	_, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
   315  	if err != nil {
   316  		handshake.mutex.RUnlock()
   317  		return nil
   318  	}
   319  	mixHash(&hash, &hash, msg.Timestamp[:])
   320  
   321  	// protect against replay & flood
   322  
   323  	replay := !timestamp.After(handshake.lastTimestamp)
   324  	flood := time.Since(handshake.lastInitiationConsumption) <= HandshakeInitationRate
   325  	handshake.mutex.RUnlock()
   326  	if replay {
   327  		device.log.Verbosef("%v - ConsumeMessageInitiation: handshake replay @ %v", peer, timestamp)
   328  		return nil
   329  	}
   330  	if flood {
   331  		device.log.Verbosef("%v - ConsumeMessageInitiation: handshake flood", peer)
   332  		return nil
   333  	}
   334  
   335  	// update handshake state
   336  
   337  	handshake.mutex.Lock()
   338  
   339  	handshake.hash = hash
   340  	handshake.chainKey = chainKey
   341  	handshake.remoteIndex = msg.Sender
   342  	handshake.remoteEphemeral = msg.Ephemeral
   343  	if timestamp.After(handshake.lastTimestamp) {
   344  		handshake.lastTimestamp = timestamp
   345  	}
   346  	now := time.Now()
   347  	if now.After(handshake.lastInitiationConsumption) {
   348  		handshake.lastInitiationConsumption = now
   349  	}
   350  	handshake.state = handshakeInitiationConsumed
   351  
   352  	handshake.mutex.Unlock()
   353  
   354  	setZero(hash[:])
   355  	setZero(chainKey[:])
   356  
   357  	return peer
   358  }
   359  
   360  func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error) {
   361  	handshake := &peer.handshake
   362  	handshake.mutex.Lock()
   363  	defer handshake.mutex.Unlock()
   364  
   365  	if handshake.state != handshakeInitiationConsumed {
   366  		return nil, errors.New("handshake initiation must be consumed first")
   367  	}
   368  
   369  	// assign index
   370  
   371  	var err error
   372  	device.indexTable.Delete(handshake.localIndex)
   373  	handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake)
   374  	if err != nil {
   375  		return nil, err
   376  	}
   377  
   378  	var msg MessageResponse
   379  	device.aSecMux.RLock()
   380  	msg.Type = MessageResponseType
   381  	device.aSecMux.RUnlock()
   382  	msg.Sender = handshake.localIndex
   383  	msg.Receiver = handshake.remoteIndex
   384  
   385  	// create ephemeral key
   386  
   387  	handshake.localEphemeral, err = newPrivateKey()
   388  	if err != nil {
   389  		return nil, err
   390  	}
   391  	msg.Ephemeral = handshake.localEphemeral.publicKey()
   392  	handshake.mixHash(msg.Ephemeral[:])
   393  	handshake.mixKey(msg.Ephemeral[:])
   394  
   395  	ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
   396  	if err != nil {
   397  		return nil, err
   398  	}
   399  	handshake.mixKey(ss[:])
   400  	ss, err = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
   401  	if err != nil {
   402  		return nil, err
   403  	}
   404  	handshake.mixKey(ss[:])
   405  
   406  	// add preshared key
   407  
   408  	var tau [blake2s.Size]byte
   409  	var key [chacha20poly1305.KeySize]byte
   410  
   411  	KDF3(
   412  		&handshake.chainKey,
   413  		&tau,
   414  		&key,
   415  		handshake.chainKey[:],
   416  		handshake.presharedKey[:],
   417  	)
   418  
   419  	handshake.mixHash(tau[:])
   420  
   421  	aead, _ := chacha20poly1305.New(key[:])
   422  	aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
   423  	handshake.mixHash(msg.Empty[:])
   424  
   425  	handshake.state = handshakeResponseCreated
   426  
   427  	return &msg, nil
   428  }
   429  
   430  func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
   431  	device.aSecMux.RLock()
   432  	if msg.Type != MessageResponseType {
   433  		device.aSecMux.RUnlock()
   434  		return nil
   435  	}
   436  	device.aSecMux.RUnlock()
   437  
   438  	// lookup handshake by receiver
   439  
   440  	lookup := device.indexTable.Lookup(msg.Receiver)
   441  	handshake := lookup.handshake
   442  	if handshake == nil {
   443  		return nil
   444  	}
   445  
   446  	var (
   447  		hash     [blake2s.Size]byte
   448  		chainKey [blake2s.Size]byte
   449  	)
   450  
   451  	ok := func() bool {
   452  		// lock handshake state
   453  
   454  		handshake.mutex.RLock()
   455  		defer handshake.mutex.RUnlock()
   456  
   457  		if handshake.state != handshakeInitiationCreated {
   458  			return false
   459  		}
   460  
   461  		// lock private key for reading
   462  
   463  		device.staticIdentity.RLock()
   464  		defer device.staticIdentity.RUnlock()
   465  
   466  		// finish 3-way DH
   467  
   468  		mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
   469  		mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])
   470  
   471  		ss, err := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
   472  		if err != nil {
   473  			return false
   474  		}
   475  		mixKey(&chainKey, &chainKey, ss[:])
   476  		setZero(ss[:])
   477  
   478  		ss, err = device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
   479  		if err != nil {
   480  			return false
   481  		}
   482  		mixKey(&chainKey, &chainKey, ss[:])
   483  		setZero(ss[:])
   484  
   485  		// add preshared key (psk)
   486  
   487  		var tau [blake2s.Size]byte
   488  		var key [chacha20poly1305.KeySize]byte
   489  		KDF3(
   490  			&chainKey,
   491  			&tau,
   492  			&key,
   493  			chainKey[:],
   494  			handshake.presharedKey[:],
   495  		)
   496  		mixHash(&hash, &hash, tau[:])
   497  
   498  		// authenticate transcript
   499  
   500  		aead, _ := chacha20poly1305.New(key[:])
   501  		_, err = aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
   502  		if err != nil {
   503  			return false
   504  		}
   505  		mixHash(&hash, &hash, msg.Empty[:])
   506  		return true
   507  	}()
   508  
   509  	if !ok {
   510  		return nil
   511  	}
   512  
   513  	// update handshake state
   514  
   515  	handshake.mutex.Lock()
   516  
   517  	handshake.hash = hash
   518  	handshake.chainKey = chainKey
   519  	handshake.remoteIndex = msg.Sender
   520  	handshake.state = handshakeResponseConsumed
   521  
   522  	handshake.mutex.Unlock()
   523  
   524  	setZero(hash[:])
   525  	setZero(chainKey[:])
   526  
   527  	return lookup.peer
   528  }
   529  
   530  /* Derives a new keypair from the current handshake state
   531   *
   532   */
   533  func (peer *Peer) BeginSymmetricSession() error {
   534  	device := peer.device
   535  	handshake := &peer.handshake
   536  	handshake.mutex.Lock()
   537  	defer handshake.mutex.Unlock()
   538  
   539  	// derive keys
   540  
   541  	var isInitiator bool
   542  	var sendKey [chacha20poly1305.KeySize]byte
   543  	var recvKey [chacha20poly1305.KeySize]byte
   544  
   545  	if handshake.state == handshakeResponseConsumed {
   546  		KDF2(
   547  			&sendKey,
   548  			&recvKey,
   549  			handshake.chainKey[:],
   550  			nil,
   551  		)
   552  		isInitiator = true
   553  	} else if handshake.state == handshakeResponseCreated {
   554  		KDF2(
   555  			&recvKey,
   556  			&sendKey,
   557  			handshake.chainKey[:],
   558  			nil,
   559  		)
   560  		isInitiator = false
   561  	} else {
   562  		return fmt.Errorf("invalid state for keypair derivation: %v", handshake.state)
   563  	}
   564  
   565  	// zero handshake
   566  
   567  	setZero(handshake.chainKey[:])
   568  	setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line.
   569  	setZero(handshake.localEphemeral[:])
   570  	peer.handshake.state = handshakeZeroed
   571  
   572  	// create AEAD instances
   573  
   574  	keypair := new(Keypair)
   575  	keypair.send, _ = chacha20poly1305.New(sendKey[:])
   576  	keypair.receive, _ = chacha20poly1305.New(recvKey[:])
   577  
   578  	setZero(sendKey[:])
   579  	setZero(recvKey[:])
   580  
   581  	keypair.created = time.Now()
   582  	keypair.replayFilter.Reset()
   583  	keypair.isInitiator = isInitiator
   584  	keypair.localIndex = peer.handshake.localIndex
   585  	keypair.remoteIndex = peer.handshake.remoteIndex
   586  
   587  	// remap index
   588  
   589  	device.indexTable.SwapIndexForKeypair(handshake.localIndex, keypair)
   590  	handshake.localIndex = 0
   591  
   592  	// rotate key pairs
   593  
   594  	keypairs := &peer.keypairs
   595  	keypairs.Lock()
   596  	defer keypairs.Unlock()
   597  
   598  	previous := keypairs.previous
   599  	next := keypairs.next.Load()
   600  	current := keypairs.current
   601  
   602  	if isInitiator {
   603  		if next != nil {
   604  			keypairs.next.Store(nil)
   605  			keypairs.previous = next
   606  			device.DeleteKeypair(current)
   607  		} else {
   608  			keypairs.previous = current
   609  		}
   610  		device.DeleteKeypair(previous)
   611  		keypairs.current = keypair
   612  	} else {
   613  		keypairs.next.Store(keypair)
   614  		device.DeleteKeypair(next)
   615  		keypairs.previous = nil
   616  		device.DeleteKeypair(previous)
   617  	}
   618  
   619  	return nil
   620  }
   621  
   622  func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
   623  	keypairs := &peer.keypairs
   624  
   625  	if keypairs.next.Load() != receivedKeypair {
   626  		return false
   627  	}
   628  	keypairs.Lock()
   629  	defer keypairs.Unlock()
   630  	if keypairs.next.Load() != receivedKeypair {
   631  		return false
   632  	}
   633  	old := keypairs.previous
   634  	keypairs.previous = keypairs.current
   635  	peer.device.DeleteKeypair(old)
   636  	keypairs.current = keypairs.next.Load()
   637  	keypairs.next.Store(nil)
   638  	return true
   639  }