github.com/sagernet/wireguard-go@v0.0.0-20231215174105-89dec3b2f3e8/device/noise-protocol.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"errors"
    10  	"fmt"
    11  	"sync"
    12  	"time"
    13  
    14  	"github.com/sagernet/wireguard-go/tai64n"
    15  	"golang.org/x/crypto/blake2s"
    16  	"golang.org/x/crypto/chacha20poly1305"
    17  	"golang.org/x/crypto/poly1305"
    18  )
    19  
    20  type handshakeState int
    21  
    22  const (
    23  	handshakeZeroed = handshakeState(iota)
    24  	handshakeInitiationCreated
    25  	handshakeInitiationConsumed
    26  	handshakeResponseCreated
    27  	handshakeResponseConsumed
    28  )
    29  
    30  func (hs handshakeState) String() string {
    31  	switch hs {
    32  	case handshakeZeroed:
    33  		return "handshakeZeroed"
    34  	case handshakeInitiationCreated:
    35  		return "handshakeInitiationCreated"
    36  	case handshakeInitiationConsumed:
    37  		return "handshakeInitiationConsumed"
    38  	case handshakeResponseCreated:
    39  		return "handshakeResponseCreated"
    40  	case handshakeResponseConsumed:
    41  		return "handshakeResponseConsumed"
    42  	default:
    43  		return fmt.Sprintf("Handshake(UNKNOWN:%d)", int(hs))
    44  	}
    45  }
    46  
    47  const (
    48  	NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"
    49  	WGIdentifier      = "WireGuard v1 zx2c4 Jason@zx2c4.com"
    50  	WGLabelMAC1       = "mac1----"
    51  	WGLabelCookie     = "cookie--"
    52  )
    53  
    54  const (
    55  	MessageInitiationType  = 1
    56  	MessageResponseType    = 2
    57  	MessageCookieReplyType = 3
    58  	MessageTransportType   = 4
    59  )
    60  
    61  const (
    62  	MessageInitiationSize      = 148                                           // size of handshake initiation message
    63  	MessageResponseSize        = 92                                            // size of response message
    64  	MessageCookieReplySize     = 64                                            // size of cookie reply message
    65  	MessageTransportHeaderSize = 16                                            // size of data preceding content in transport message
    66  	MessageTransportSize       = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport
    67  	MessageKeepaliveSize       = MessageTransportSize                          // size of keepalive
    68  	MessageHandshakeSize       = MessageInitiationSize                         // size of largest handshake related message
    69  )
    70  
    71  const (
    72  	MessageTransportOffsetReceiver = 4
    73  	MessageTransportOffsetCounter  = 8
    74  	MessageTransportOffsetContent  = 16
    75  )
    76  
/* Type is an 8-bit field followed by 3 nul bytes. Because the messages
 * are marshalled in little-endian byte order, the type byte and its padding
 * can be treated together as a 32-bit unsigned int (for now).
 */
    82  
// MessageInitiation is the first handshake message, sent by the initiator.
// It is built by CreateMessageInitiation and checked by
// ConsumeMessageInitiation. Field order is assumed to be the wire layout
// (the fields presumably sum to MessageInitiationSize — confirm).
type MessageInitiation struct {
	Type      uint32                                        // MessageInitiationType
	Sender    uint32                                        // initiator's local index for this handshake
	Ephemeral NoisePublicKey                                // initiator's ephemeral public key, in the clear
	Static    [NoisePublicKeySize + poly1305.TagSize]byte   // initiator's static public key, AEAD-sealed
	Timestamp [tai64n.TimestampSize + poly1305.TagSize]byte // TAI64N timestamp, AEAD-sealed; used for replay protection
	MAC1      [blake2s.Size128]byte                         // not set in this file; presumably filled by the MAC/cookie code — verify
	MAC2      [blake2s.Size128]byte                         // not set in this file; presumably filled by the MAC/cookie code — verify
}
    92  
// MessageResponse is the second handshake message, sent by the responder.
// It is built by CreateMessageResponse and checked by
// ConsumeMessageResponse. Field order is assumed to be the wire layout.
type MessageResponse struct {
	Type      uint32                 // MessageResponseType
	Sender    uint32                 // responder's local index for this handshake
	Receiver  uint32                 // initiator's index, copied from the initiation's Sender
	Ephemeral NoisePublicKey         // responder's ephemeral public key, in the clear
	Empty     [poly1305.TagSize]byte // AEAD tag over an empty plaintext; authenticates the transcript
	MAC1      [blake2s.Size128]byte  // not set in this file; presumably filled by the MAC/cookie code — verify
	MAC2      [blake2s.Size128]byte  // not set in this file; presumably filled by the MAC/cookie code — verify
}
   102  
// MessageTransport describes a data message. The fixed header fields sit at
// MessageTransportOffsetReceiver / MessageTransportOffsetCounter and the
// content starts at MessageTransportOffsetContent.
type MessageTransport struct {
	Type     uint32 // MessageTransportType
	Receiver uint32 // receiver's index for the keypair in use
	Counter  uint64 // presumably the per-keypair nonce counter — confirm in the send/receive code
	Content  []byte // variable-length payload; presumably AEAD-sealed — confirm
}
   109  
// MessageCookieReply is the cookie reply message (MessageCookieReplyType).
// It is neither built nor parsed in this file; the field notes below are
// inferred from the field types only.
type MessageCookieReply struct {
	Type     uint32                                   // MessageCookieReplyType
	Receiver uint32                                   // presumably the index of the handshake being answered — verify
	Nonce    [chacha20poly1305.NonceSizeX]byte        // nonce sized for XChaCha20-Poly1305
	Cookie   [blake2s.Size128 + poly1305.TagSize]byte // 16-byte cookie plus an AEAD tag (presumably sealed — verify)
}
   116  
// Handshake is the per-peer Noise handshake state. The functions in this
// file read and write its fields with mutex held (Clear does not lock on
// its own; callers presumably hold mutex — verify).
type Handshake struct {
	state                     handshakeState           // progress of the current handshake
	mutex                     sync.RWMutex             // guards the fields below
	hash                      [blake2s.Size]byte       // running transcript hash
	chainKey                  [blake2s.Size]byte       // Noise chaining key
	presharedKey              NoisePresharedKey        // psk mixed in while building/consuming the response
	localEphemeral            NoisePrivateKey          // ephemeral secret key
	localIndex                uint32                   // our index for this handshake in device.indexTable; used to clear hash-table
	remoteIndex               uint32                   // peer's index; written into outgoing messages
	remoteStatic              NoisePublicKey           // peer's long-term public key
	remoteEphemeral           NoisePublicKey           // peer's ephemeral public key from its initiation
	precomputedStaticStatic   [NoisePublicKeySize]byte // precomputed static-static shared secret; all zero means unusable
	lastTimestamp             tai64n.Timestamp         // newest accepted initiation timestamp (replay check)
	lastInitiationConsumption time.Time                // when an initiation was last consumed (flood check)
	lastSentHandshake         time.Time                // not used in this file; presumably maintained by the send/timer code
}
   133  
   134  var (
   135  	InitialChainKey [blake2s.Size]byte
   136  	InitialHash     [blake2s.Size]byte
   137  	ZeroNonce       [chacha20poly1305.NonceSize]byte
   138  )
   139  
   140  func mixKey(dst, c *[blake2s.Size]byte, data []byte) {
   141  	KDF1(dst, c[:], data)
   142  }
   143  
   144  func mixHash(dst, h *[blake2s.Size]byte, data []byte) {
   145  	hash, _ := blake2s.New256(nil)
   146  	hash.Write(h[:])
   147  	hash.Write(data)
   148  	hash.Sum(dst[:0])
   149  	hash.Reset()
   150  }
   151  
   152  func (h *Handshake) Clear() {
   153  	setZero(h.localEphemeral[:])
   154  	setZero(h.remoteEphemeral[:])
   155  	setZero(h.chainKey[:])
   156  	setZero(h.hash[:])
   157  	h.localIndex = 0
   158  	h.state = handshakeZeroed
   159  }
   160  
   161  func (h *Handshake) mixHash(data []byte) {
   162  	mixHash(&h.hash, &h.hash, data)
   163  }
   164  
   165  func (h *Handshake) mixKey(data []byte) {
   166  	mixKey(&h.chainKey, &h.chainKey, data)
   167  }
   168  
   169  /* Do basic precomputations
   170   */
   171  func init() {
   172  	InitialChainKey = blake2s.Sum256([]byte(NoiseConstruction))
   173  	mixHash(&InitialHash, &InitialChainKey, []byte(WGIdentifier))
   174  }
   175  
   176  func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
   177  	device.staticIdentity.RLock()
   178  	defer device.staticIdentity.RUnlock()
   179  
   180  	handshake := &peer.handshake
   181  	handshake.mutex.Lock()
   182  	defer handshake.mutex.Unlock()
   183  
   184  	// create ephemeral key
   185  	var err error
   186  	handshake.hash = InitialHash
   187  	handshake.chainKey = InitialChainKey
   188  	handshake.localEphemeral, err = newPrivateKey()
   189  	if err != nil {
   190  		return nil, err
   191  	}
   192  
   193  	handshake.mixHash(handshake.remoteStatic[:])
   194  
   195  	msg := MessageInitiation{
   196  		Type:      MessageInitiationType,
   197  		Ephemeral: handshake.localEphemeral.publicKey(),
   198  	}
   199  
   200  	handshake.mixKey(msg.Ephemeral[:])
   201  	handshake.mixHash(msg.Ephemeral[:])
   202  
   203  	// encrypt static key
   204  	ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
   205  	if err != nil {
   206  		return nil, err
   207  	}
   208  	var key [chacha20poly1305.KeySize]byte
   209  	KDF2(
   210  		&handshake.chainKey,
   211  		&key,
   212  		handshake.chainKey[:],
   213  		ss[:],
   214  	)
   215  	aead, _ := chacha20poly1305.New(key[:])
   216  	aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:])
   217  	handshake.mixHash(msg.Static[:])
   218  
   219  	// encrypt timestamp
   220  	if isZero(handshake.precomputedStaticStatic[:]) {
   221  		return nil, errInvalidPublicKey
   222  	}
   223  	KDF2(
   224  		&handshake.chainKey,
   225  		&key,
   226  		handshake.chainKey[:],
   227  		handshake.precomputedStaticStatic[:],
   228  	)
   229  	timestamp := tai64n.Now()
   230  	aead, _ = chacha20poly1305.New(key[:])
   231  	aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
   232  
   233  	// assign index
   234  	device.indexTable.Delete(handshake.localIndex)
   235  	msg.Sender, err = device.indexTable.NewIndexForHandshake(peer, handshake)
   236  	if err != nil {
   237  		return nil, err
   238  	}
   239  	handshake.localIndex = msg.Sender
   240  
   241  	handshake.mixHash(msg.Timestamp[:])
   242  	handshake.state = handshakeInitiationCreated
   243  	return &msg, nil
   244  }
   245  
   246  func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
   247  	var (
   248  		hash     [blake2s.Size]byte
   249  		chainKey [blake2s.Size]byte
   250  	)
   251  
   252  	if msg.Type != MessageInitiationType {
   253  		return nil
   254  	}
   255  
   256  	device.staticIdentity.RLock()
   257  	defer device.staticIdentity.RUnlock()
   258  
   259  	mixHash(&hash, &InitialHash, device.staticIdentity.publicKey[:])
   260  	mixHash(&hash, &hash, msg.Ephemeral[:])
   261  	mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
   262  
   263  	// decrypt static key
   264  	var peerPK NoisePublicKey
   265  	var key [chacha20poly1305.KeySize]byte
   266  	ss, err := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
   267  	if err != nil {
   268  		return nil
   269  	}
   270  	KDF2(&chainKey, &key, chainKey[:], ss[:])
   271  	aead, _ := chacha20poly1305.New(key[:])
   272  	_, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
   273  	if err != nil {
   274  		return nil
   275  	}
   276  	mixHash(&hash, &hash, msg.Static[:])
   277  
   278  	// lookup peer
   279  
   280  	peer := device.LookupPeer(peerPK)
   281  	if peer == nil || !peer.isRunning.Load() {
   282  		return nil
   283  	}
   284  
   285  	handshake := &peer.handshake
   286  
   287  	// verify identity
   288  
   289  	var timestamp tai64n.Timestamp
   290  
   291  	handshake.mutex.RLock()
   292  
   293  	if isZero(handshake.precomputedStaticStatic[:]) {
   294  		handshake.mutex.RUnlock()
   295  		return nil
   296  	}
   297  	KDF2(
   298  		&chainKey,
   299  		&key,
   300  		chainKey[:],
   301  		handshake.precomputedStaticStatic[:],
   302  	)
   303  	aead, _ = chacha20poly1305.New(key[:])
   304  	_, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
   305  	if err != nil {
   306  		handshake.mutex.RUnlock()
   307  		return nil
   308  	}
   309  	mixHash(&hash, &hash, msg.Timestamp[:])
   310  
   311  	// protect against replay & flood
   312  
   313  	replay := !timestamp.After(handshake.lastTimestamp)
   314  	flood := time.Since(handshake.lastInitiationConsumption) <= HandshakeInitationRate
   315  	handshake.mutex.RUnlock()
   316  	if replay {
   317  		device.log.Verbosef("%v - ConsumeMessageInitiation: handshake replay @ %v", peer, timestamp)
   318  		return nil
   319  	}
   320  	if flood {
   321  		device.log.Verbosef("%v - ConsumeMessageInitiation: handshake flood", peer)
   322  		return nil
   323  	}
   324  
   325  	// update handshake state
   326  
   327  	handshake.mutex.Lock()
   328  
   329  	handshake.hash = hash
   330  	handshake.chainKey = chainKey
   331  	handshake.remoteIndex = msg.Sender
   332  	handshake.remoteEphemeral = msg.Ephemeral
   333  	if timestamp.After(handshake.lastTimestamp) {
   334  		handshake.lastTimestamp = timestamp
   335  	}
   336  	now := time.Now()
   337  	if now.After(handshake.lastInitiationConsumption) {
   338  		handshake.lastInitiationConsumption = now
   339  	}
   340  	handshake.state = handshakeInitiationConsumed
   341  
   342  	handshake.mutex.Unlock()
   343  
   344  	setZero(hash[:])
   345  	setZero(chainKey[:])
   346  
   347  	return peer
   348  }
   349  
   350  func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error) {
   351  	handshake := &peer.handshake
   352  	handshake.mutex.Lock()
   353  	defer handshake.mutex.Unlock()
   354  
   355  	if handshake.state != handshakeInitiationConsumed {
   356  		return nil, errors.New("handshake initiation must be consumed first")
   357  	}
   358  
   359  	// assign index
   360  
   361  	var err error
   362  	device.indexTable.Delete(handshake.localIndex)
   363  	handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake)
   364  	if err != nil {
   365  		return nil, err
   366  	}
   367  
   368  	var msg MessageResponse
   369  	msg.Type = MessageResponseType
   370  	msg.Sender = handshake.localIndex
   371  	msg.Receiver = handshake.remoteIndex
   372  
   373  	// create ephemeral key
   374  
   375  	handshake.localEphemeral, err = newPrivateKey()
   376  	if err != nil {
   377  		return nil, err
   378  	}
   379  	msg.Ephemeral = handshake.localEphemeral.publicKey()
   380  	handshake.mixHash(msg.Ephemeral[:])
   381  	handshake.mixKey(msg.Ephemeral[:])
   382  
   383  	ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
   384  	if err != nil {
   385  		return nil, err
   386  	}
   387  	handshake.mixKey(ss[:])
   388  	ss, err = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
   389  	if err != nil {
   390  		return nil, err
   391  	}
   392  	handshake.mixKey(ss[:])
   393  
   394  	// add preshared key
   395  
   396  	var tau [blake2s.Size]byte
   397  	var key [chacha20poly1305.KeySize]byte
   398  
   399  	KDF3(
   400  		&handshake.chainKey,
   401  		&tau,
   402  		&key,
   403  		handshake.chainKey[:],
   404  		handshake.presharedKey[:],
   405  	)
   406  
   407  	handshake.mixHash(tau[:])
   408  
   409  	aead, _ := chacha20poly1305.New(key[:])
   410  	aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
   411  	handshake.mixHash(msg.Empty[:])
   412  
   413  	handshake.state = handshakeResponseCreated
   414  
   415  	return &msg, nil
   416  }
   417  
   418  func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
   419  	if msg.Type != MessageResponseType {
   420  		return nil
   421  	}
   422  
   423  	// lookup handshake by receiver
   424  
   425  	lookup := device.indexTable.Lookup(msg.Receiver)
   426  	handshake := lookup.handshake
   427  	if handshake == nil {
   428  		return nil
   429  	}
   430  
   431  	var (
   432  		hash     [blake2s.Size]byte
   433  		chainKey [blake2s.Size]byte
   434  	)
   435  
   436  	ok := func() bool {
   437  		// lock handshake state
   438  
   439  		handshake.mutex.RLock()
   440  		defer handshake.mutex.RUnlock()
   441  
   442  		if handshake.state != handshakeInitiationCreated {
   443  			return false
   444  		}
   445  
   446  		// lock private key for reading
   447  
   448  		device.staticIdentity.RLock()
   449  		defer device.staticIdentity.RUnlock()
   450  
   451  		// finish 3-way DH
   452  
   453  		mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
   454  		mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])
   455  
   456  		ss, err := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
   457  		if err != nil {
   458  			return false
   459  		}
   460  		mixKey(&chainKey, &chainKey, ss[:])
   461  		setZero(ss[:])
   462  
   463  		ss, err = device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
   464  		if err != nil {
   465  			return false
   466  		}
   467  		mixKey(&chainKey, &chainKey, ss[:])
   468  		setZero(ss[:])
   469  
   470  		// add preshared key (psk)
   471  
   472  		var tau [blake2s.Size]byte
   473  		var key [chacha20poly1305.KeySize]byte
   474  		KDF3(
   475  			&chainKey,
   476  			&tau,
   477  			&key,
   478  			chainKey[:],
   479  			handshake.presharedKey[:],
   480  		)
   481  		mixHash(&hash, &hash, tau[:])
   482  
   483  		// authenticate transcript
   484  
   485  		aead, _ := chacha20poly1305.New(key[:])
   486  		_, err = aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
   487  		if err != nil {
   488  			return false
   489  		}
   490  		mixHash(&hash, &hash, msg.Empty[:])
   491  		return true
   492  	}()
   493  
   494  	if !ok {
   495  		return nil
   496  	}
   497  
   498  	// update handshake state
   499  
   500  	handshake.mutex.Lock()
   501  
   502  	handshake.hash = hash
   503  	handshake.chainKey = chainKey
   504  	handshake.remoteIndex = msg.Sender
   505  	handshake.state = handshakeResponseConsumed
   506  
   507  	handshake.mutex.Unlock()
   508  
   509  	setZero(hash[:])
   510  	setZero(chainKey[:])
   511  
   512  	return lookup.peer
   513  }
   514  
   515  /* Derives a new keypair from the current handshake state
   516   *
   517   */
   518  func (peer *Peer) BeginSymmetricSession() error {
   519  	device := peer.device
   520  	handshake := &peer.handshake
   521  	handshake.mutex.Lock()
   522  	defer handshake.mutex.Unlock()
   523  
   524  	// derive keys
   525  
   526  	var isInitiator bool
   527  	var sendKey [chacha20poly1305.KeySize]byte
   528  	var recvKey [chacha20poly1305.KeySize]byte
   529  
   530  	if handshake.state == handshakeResponseConsumed {
   531  		KDF2(
   532  			&sendKey,
   533  			&recvKey,
   534  			handshake.chainKey[:],
   535  			nil,
   536  		)
   537  		isInitiator = true
   538  	} else if handshake.state == handshakeResponseCreated {
   539  		KDF2(
   540  			&recvKey,
   541  			&sendKey,
   542  			handshake.chainKey[:],
   543  			nil,
   544  		)
   545  		isInitiator = false
   546  	} else {
   547  		return fmt.Errorf("invalid state for keypair derivation: %v", handshake.state)
   548  	}
   549  
   550  	// zero handshake
   551  
   552  	setZero(handshake.chainKey[:])
   553  	setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line.
   554  	setZero(handshake.localEphemeral[:])
   555  	peer.handshake.state = handshakeZeroed
   556  
   557  	// create AEAD instances
   558  
   559  	keypair := new(Keypair)
   560  	keypair.send, _ = chacha20poly1305.New(sendKey[:])
   561  	keypair.receive, _ = chacha20poly1305.New(recvKey[:])
   562  
   563  	setZero(sendKey[:])
   564  	setZero(recvKey[:])
   565  
   566  	keypair.created = time.Now()
   567  	keypair.replayFilter.Reset()
   568  	keypair.isInitiator = isInitiator
   569  	keypair.localIndex = peer.handshake.localIndex
   570  	keypair.remoteIndex = peer.handshake.remoteIndex
   571  
   572  	// remap index
   573  
   574  	device.indexTable.SwapIndexForKeypair(handshake.localIndex, keypair)
   575  	handshake.localIndex = 0
   576  
   577  	// rotate key pairs
   578  
   579  	keypairs := &peer.keypairs
   580  	keypairs.Lock()
   581  	defer keypairs.Unlock()
   582  
   583  	previous := keypairs.previous
   584  	next := keypairs.next.Load()
   585  	current := keypairs.current
   586  
   587  	if isInitiator {
   588  		if next != nil {
   589  			keypairs.next.Store(nil)
   590  			keypairs.previous = next
   591  			device.DeleteKeypair(current)
   592  		} else {
   593  			keypairs.previous = current
   594  		}
   595  		device.DeleteKeypair(previous)
   596  		keypairs.current = keypair
   597  	} else {
   598  		keypairs.next.Store(keypair)
   599  		device.DeleteKeypair(next)
   600  		keypairs.previous = nil
   601  		device.DeleteKeypair(previous)
   602  	}
   603  
   604  	return nil
   605  }
   606  
   607  func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
   608  	keypairs := &peer.keypairs
   609  
   610  	if keypairs.next.Load() != receivedKeypair {
   611  		return false
   612  	}
   613  	keypairs.Lock()
   614  	defer keypairs.Unlock()
   615  	if keypairs.next.Load() != receivedKeypair {
   616  		return false
   617  	}
   618  	old := keypairs.previous
   619  	keypairs.previous = keypairs.current
   620  	peer.device.DeleteKeypair(old)
   621  	keypairs.current = keypairs.next.Load()
   622  	keypairs.next.Store(nil)
   623  	return true
   624  }