github.com/liloew/wireguard-go@v0.0.0-20220224014633-9cd745e6f114/device/noise-protocol.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"errors"
    10  	"fmt"
    11  	"sync"
    12  	"time"
    13  
    14  	"golang.org/x/crypto/blake2s"
    15  	"golang.org/x/crypto/chacha20poly1305"
    16  	"golang.org/x/crypto/poly1305"
    17  
    18  	"github.com/liloew/wireguard-go/tai64n"
    19  )
    20  
    21  type handshakeState int
    22  
    23  const (
    24  	handshakeZeroed = handshakeState(iota)
    25  	handshakeInitiationCreated
    26  	handshakeInitiationConsumed
    27  	handshakeResponseCreated
    28  	handshakeResponseConsumed
    29  )
    30  
    31  func (hs handshakeState) String() string {
    32  	switch hs {
    33  	case handshakeZeroed:
    34  		return "handshakeZeroed"
    35  	case handshakeInitiationCreated:
    36  		return "handshakeInitiationCreated"
    37  	case handshakeInitiationConsumed:
    38  		return "handshakeInitiationConsumed"
    39  	case handshakeResponseCreated:
    40  		return "handshakeResponseCreated"
    41  	case handshakeResponseConsumed:
    42  		return "handshakeResponseConsumed"
    43  	default:
    44  		return fmt.Sprintf("Handshake(UNKNOWN:%d)", int(hs))
    45  	}
    46  }
    47  
// Protocol identifiers and labels.
const (
	// NoiseConstruction is the Noise protocol name; init hashes it with
	// BLAKE2s-256 to obtain InitialChainKey.
	NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"
	// WGIdentifier is mixed into InitialChainKey by init to form InitialHash.
	WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com"
	// Labels for MAC1 and cookie key derivation. They are not referenced in
	// this file; presumably they are used by the MAC/cookie code elsewhere.
	WGLabelMAC1   = "mac1----"
	WGLabelCookie = "cookie--"
)
    54  
    55  const (
    56  	MessageInitiationType  = 1
    57  	MessageResponseType    = 2
    58  	MessageCookieReplyType = 3
    59  	MessageTransportType   = 4
    60  )
    61  
// Sizes of the protocol messages, in bytes.
//
// The handshake sizes are the sums of the fixed-size fields of the
// corresponding Message* structs in this file. This assumes
// NoisePublicKeySize = 32 and tai64n.TimestampSize = 12; neither value is
// visible here, so TODO confirm:
//
//	initiation:   4 (Type) + 4 (Sender) + 32 (Ephemeral) + 48 (Static)
//	              + 28 (Timestamp) + 16 (MAC1) + 16 (MAC2) = 148
//	response:     4 (Type) + 4 (Sender) + 4 (Receiver) + 32 (Ephemeral)
//	              + 16 (Empty) + 16 (MAC1) + 16 (MAC2) = 92
//	cookie reply: 4 (Type) + 4 (Receiver) + 24 (Nonce) + 32 (Cookie) = 64
const (
	MessageInitiationSize      = 148                                           // size of handshake initiation message
	MessageResponseSize        = 92                                            // size of response message
	MessageCookieReplySize     = 64                                            // size of cookie reply message
	MessageTransportHeaderSize = 16                                            // size of data preceding content in transport message
	MessageTransportSize       = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport
	MessageKeepaliveSize       = MessageTransportSize                          // size of keepalive
	MessageHandshakeSize       = MessageInitiationSize                         // size of largest handshake related message
)
    71  
// Byte offsets of the fields of a transport message. They match the layout
// of MessageTransport:
//   - Type occupies bytes 0-3.
//   - Receiver occupies bytes 4-7.
//   - Counter occupies bytes 8-15.
//   - Content starts at byte 16, which equals MessageTransportHeaderSize.
const (
	MessageTransportOffsetReceiver = 4
	MessageTransportOffsetCounter  = 8
	MessageTransportOffsetContent  = 16
)
    77  
/* Type is an 8-bit field followed by 3 zero bytes. Because the messages
 * are marshalled in little-endian byte order, these four bytes can be
 * treated as a single 32-bit unsigned integer (for now).
 */
    83  
// MessageInitiation is the first handshake message, sent by the initiator
// (built by CreateMessageInitiation, parsed by ConsumeMessageInitiation).
// Field order presumably matches the wire layout, since the field sizes sum
// to MessageInitiationSize. Do not reorder without checking the marshalling
// code.
type MessageInitiation struct {
	// MessageInitiationType.
	Type uint32
	// Initiator's local index; the responder stores it as remoteIndex.
	Sender uint32
	// Initiator's ephemeral public key, sent in the clear.
	Ephemeral NoisePublicKey
	// Initiator's static public key, sealed with ChaCha20-Poly1305.
	Static [NoisePublicKeySize + poly1305.TagSize]byte
	// TAI64N timestamp, sealed with ChaCha20-Poly1305 (replay protection).
	Timestamp [tai64n.TimestampSize + poly1305.TagSize]byte
	// MAC1 and MAC2 are not filled in by this file; presumably the
	// cookie/MAC code sets and checks them.
	MAC1 [blake2s.Size128]byte
	MAC2 [blake2s.Size128]byte
}
    93  
// MessageResponse is the second handshake message, sent by the responder
// (built by CreateMessageResponse, parsed by ConsumeMessageResponse).
// Field order presumably matches the wire layout (the sizes sum to
// MessageResponseSize).
type MessageResponse struct {
	// MessageResponseType.
	Type uint32
	// Responder's local index.
	Sender uint32
	// The initiator's index, echoed back so the initiator can find its
	// pending handshake in the index table.
	Receiver uint32
	// Responder's ephemeral public key, sent in the clear.
	Ephemeral NoisePublicKey
	// AEAD tag over an empty plaintext; it authenticates the handshake
	// transcript.
	Empty [poly1305.TagSize]byte
	// MAC1 and MAC2 are not filled in by this file; presumably the
	// cookie/MAC code handles them.
	MAC1 [blake2s.Size128]byte
	MAC2 [blake2s.Size128]byte
}
   103  
// MessageTransport is a data message. Its 16-byte header (Type, Receiver,
// Counter) is located by the MessageTransportOffset* constants; Content is
// variable-length. Not used in this file.
type MessageTransport struct {
	Type     uint32 // MessageTransportType
	Receiver uint32 // presumably the receiving peer's index — confirm
	Counter  uint64 // presumably the AEAD nonce counter — confirm
	Content  []byte
}
   110  
// MessageCookieReply carries a sealed cookie (16 bytes plus a Poly1305 tag)
// under an XChaCha20-Poly1305-sized nonce. Not used in this file; presumably
// handled by the cookie code — confirm.
type MessageCookieReply struct {
	Type     uint32 // MessageCookieReplyType
	Receiver uint32
	Nonce    [chacha20poly1305.NonceSizeX]byte
	Cookie   [blake2s.Size128 + poly1305.TagSize]byte
}
   117  
// Handshake holds one peer's Noise IKpsk2 handshake state. It contains:
//   - the running transcript hash and chaining key;
//   - the keys involved in the handshake;
//   - the indices used to route messages;
//   - bookkeeping for replay and flood protection.
//
// The handshake functions in this file hold mutex while reading or writing
// the other fields. Clear does not take it itself.
type Handshake struct {
	state                     handshakeState
	mutex                     sync.RWMutex
	hash                      [blake2s.Size]byte       // running transcript hash
	chainKey                  [blake2s.Size]byte       // running chaining key
	presharedKey              NoisePresharedKey        // psk, mixed in via KDF3
	localEphemeral            NoisePrivateKey          // ephemeral secret key for the current handshake
	localIndex                uint32                   // this side's index in device.indexTable; used to clear it
	remoteIndex               uint32                   // peer's index (the Sender of its last message); used for sending
	remoteStatic              NoisePublicKey           // peer's long-term static public key
	remoteEphemeral           NoisePublicKey           // peer's ephemeral public key
	precomputedStaticStatic   [NoisePublicKeySize]byte // precomputed static-static shared secret
	lastTimestamp             tai64n.Timestamp         // newest accepted initiation timestamp (replay protection)
	lastInitiationConsumption time.Time                // when an initiation was last consumed (flood protection)
	lastSentHandshake         time.Time
}
   134  
var (
	// InitialChainKey is BLAKE2s-256(NoiseConstruction); computed in init.
	InitialChainKey [blake2s.Size]byte
	// InitialHash is BLAKE2s-256(InitialChainKey || WGIdentifier); computed
	// in init.
	InitialHash [blake2s.Size]byte
	// ZeroNonce is the all-zero nonce used for every handshake AEAD
	// operation in this file. Each such operation uses a key freshly derived
	// by KDF2/KDF3 immediately beforehand.
	ZeroNonce [chacha20poly1305.NonceSize]byte
)
   140  
   141  func mixKey(dst, c *[blake2s.Size]byte, data []byte) {
   142  	KDF1(dst, c[:], data)
   143  }
   144  
   145  func mixHash(dst, h *[blake2s.Size]byte, data []byte) {
   146  	hash, _ := blake2s.New256(nil)
   147  	hash.Write(h[:])
   148  	hash.Write(data)
   149  	hash.Sum(dst[:0])
   150  	hash.Reset()
   151  }
   152  
   153  func (h *Handshake) Clear() {
   154  	setZero(h.localEphemeral[:])
   155  	setZero(h.remoteEphemeral[:])
   156  	setZero(h.chainKey[:])
   157  	setZero(h.hash[:])
   158  	h.localIndex = 0
   159  	h.state = handshakeZeroed
   160  }
   161  
   162  func (h *Handshake) mixHash(data []byte) {
   163  	mixHash(&h.hash, &h.hash, data)
   164  }
   165  
   166  func (h *Handshake) mixKey(data []byte) {
   167  	mixKey(&h.chainKey, &h.chainKey, data)
   168  }
   169  
   170  /* Do basic precomputations
   171   */
   172  func init() {
   173  	InitialChainKey = blake2s.Sum256([]byte(NoiseConstruction))
   174  	mixHash(&InitialHash, &InitialChainKey, []byte(WGIdentifier))
   175  }
   176  
   177  func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
   178  	errZeroECDHResult := errors.New("ECDH returned all zeros")
   179  
   180  	device.staticIdentity.RLock()
   181  	defer device.staticIdentity.RUnlock()
   182  
   183  	handshake := &peer.handshake
   184  	handshake.mutex.Lock()
   185  	defer handshake.mutex.Unlock()
   186  
   187  	// create ephemeral key
   188  	var err error
   189  	handshake.hash = InitialHash
   190  	handshake.chainKey = InitialChainKey
   191  	handshake.localEphemeral, err = newPrivateKey()
   192  	if err != nil {
   193  		return nil, err
   194  	}
   195  
   196  	handshake.mixHash(handshake.remoteStatic[:])
   197  
   198  	msg := MessageInitiation{
   199  		Type:      MessageInitiationType,
   200  		Ephemeral: handshake.localEphemeral.publicKey(),
   201  	}
   202  
   203  	handshake.mixKey(msg.Ephemeral[:])
   204  	handshake.mixHash(msg.Ephemeral[:])
   205  
   206  	// encrypt static key
   207  	ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
   208  	if isZero(ss[:]) {
   209  		return nil, errZeroECDHResult
   210  	}
   211  	var key [chacha20poly1305.KeySize]byte
   212  	KDF2(
   213  		&handshake.chainKey,
   214  		&key,
   215  		handshake.chainKey[:],
   216  		ss[:],
   217  	)
   218  	aead, _ := chacha20poly1305.New(key[:])
   219  	aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:])
   220  	handshake.mixHash(msg.Static[:])
   221  
   222  	// encrypt timestamp
   223  	if isZero(handshake.precomputedStaticStatic[:]) {
   224  		return nil, errZeroECDHResult
   225  	}
   226  	KDF2(
   227  		&handshake.chainKey,
   228  		&key,
   229  		handshake.chainKey[:],
   230  		handshake.precomputedStaticStatic[:],
   231  	)
   232  	timestamp := tai64n.Now()
   233  	aead, _ = chacha20poly1305.New(key[:])
   234  	aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
   235  
   236  	// assign index
   237  	device.indexTable.Delete(handshake.localIndex)
   238  	msg.Sender, err = device.indexTable.NewIndexForHandshake(peer, handshake)
   239  	if err != nil {
   240  		return nil, err
   241  	}
   242  	handshake.localIndex = msg.Sender
   243  
   244  	handshake.mixHash(msg.Timestamp[:])
   245  	handshake.state = handshakeInitiationCreated
   246  	return &msg, nil
   247  }
   248  
   249  func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
   250  	var (
   251  		hash     [blake2s.Size]byte
   252  		chainKey [blake2s.Size]byte
   253  	)
   254  
   255  	if msg.Type != MessageInitiationType {
   256  		return nil
   257  	}
   258  
   259  	device.staticIdentity.RLock()
   260  	defer device.staticIdentity.RUnlock()
   261  
   262  	mixHash(&hash, &InitialHash, device.staticIdentity.publicKey[:])
   263  	mixHash(&hash, &hash, msg.Ephemeral[:])
   264  	mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
   265  
   266  	// decrypt static key
   267  	var err error
   268  	var peerPK NoisePublicKey
   269  	var key [chacha20poly1305.KeySize]byte
   270  	ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
   271  	if isZero(ss[:]) {
   272  		return nil
   273  	}
   274  	KDF2(&chainKey, &key, chainKey[:], ss[:])
   275  	aead, _ := chacha20poly1305.New(key[:])
   276  	_, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
   277  	if err != nil {
   278  		return nil
   279  	}
   280  	mixHash(&hash, &hash, msg.Static[:])
   281  
   282  	// lookup peer
   283  
   284  	peer := device.LookupPeer(peerPK)
   285  	if peer == nil || !peer.isRunning.Get() {
   286  		return nil
   287  	}
   288  
   289  	handshake := &peer.handshake
   290  
   291  	// verify identity
   292  
   293  	var timestamp tai64n.Timestamp
   294  
   295  	handshake.mutex.RLock()
   296  
   297  	if isZero(handshake.precomputedStaticStatic[:]) {
   298  		handshake.mutex.RUnlock()
   299  		return nil
   300  	}
   301  	KDF2(
   302  		&chainKey,
   303  		&key,
   304  		chainKey[:],
   305  		handshake.precomputedStaticStatic[:],
   306  	)
   307  	aead, _ = chacha20poly1305.New(key[:])
   308  	_, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
   309  	if err != nil {
   310  		handshake.mutex.RUnlock()
   311  		return nil
   312  	}
   313  	mixHash(&hash, &hash, msg.Timestamp[:])
   314  
   315  	// protect against replay & flood
   316  
   317  	replay := !timestamp.After(handshake.lastTimestamp)
   318  	flood := time.Since(handshake.lastInitiationConsumption) <= HandshakeInitationRate
   319  	handshake.mutex.RUnlock()
   320  	if replay {
   321  		device.log.Verbosef("%v - ConsumeMessageInitiation: handshake replay @ %v", peer, timestamp)
   322  		return nil
   323  	}
   324  	if flood {
   325  		device.log.Verbosef("%v - ConsumeMessageInitiation: handshake flood", peer)
   326  		return nil
   327  	}
   328  
   329  	// update handshake state
   330  
   331  	handshake.mutex.Lock()
   332  
   333  	handshake.hash = hash
   334  	handshake.chainKey = chainKey
   335  	handshake.remoteIndex = msg.Sender
   336  	handshake.remoteEphemeral = msg.Ephemeral
   337  	if timestamp.After(handshake.lastTimestamp) {
   338  		handshake.lastTimestamp = timestamp
   339  	}
   340  	now := time.Now()
   341  	if now.After(handshake.lastInitiationConsumption) {
   342  		handshake.lastInitiationConsumption = now
   343  	}
   344  	handshake.state = handshakeInitiationConsumed
   345  
   346  	handshake.mutex.Unlock()
   347  
   348  	setZero(hash[:])
   349  	setZero(chainKey[:])
   350  
   351  	return peer
   352  }
   353  
   354  func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error) {
   355  	handshake := &peer.handshake
   356  	handshake.mutex.Lock()
   357  	defer handshake.mutex.Unlock()
   358  
   359  	if handshake.state != handshakeInitiationConsumed {
   360  		return nil, errors.New("handshake initiation must be consumed first")
   361  	}
   362  
   363  	// assign index
   364  
   365  	var err error
   366  	device.indexTable.Delete(handshake.localIndex)
   367  	handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake)
   368  	if err != nil {
   369  		return nil, err
   370  	}
   371  
   372  	var msg MessageResponse
   373  	msg.Type = MessageResponseType
   374  	msg.Sender = handshake.localIndex
   375  	msg.Receiver = handshake.remoteIndex
   376  
   377  	// create ephemeral key
   378  
   379  	handshake.localEphemeral, err = newPrivateKey()
   380  	if err != nil {
   381  		return nil, err
   382  	}
   383  	msg.Ephemeral = handshake.localEphemeral.publicKey()
   384  	handshake.mixHash(msg.Ephemeral[:])
   385  	handshake.mixKey(msg.Ephemeral[:])
   386  
   387  	func() {
   388  		ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
   389  		handshake.mixKey(ss[:])
   390  		ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
   391  		handshake.mixKey(ss[:])
   392  	}()
   393  
   394  	// add preshared key
   395  
   396  	var tau [blake2s.Size]byte
   397  	var key [chacha20poly1305.KeySize]byte
   398  
   399  	KDF3(
   400  		&handshake.chainKey,
   401  		&tau,
   402  		&key,
   403  		handshake.chainKey[:],
   404  		handshake.presharedKey[:],
   405  	)
   406  
   407  	handshake.mixHash(tau[:])
   408  
   409  	func() {
   410  		aead, _ := chacha20poly1305.New(key[:])
   411  		aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
   412  		handshake.mixHash(msg.Empty[:])
   413  	}()
   414  
   415  	handshake.state = handshakeResponseCreated
   416  
   417  	return &msg, nil
   418  }
   419  
   420  func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
   421  	if msg.Type != MessageResponseType {
   422  		return nil
   423  	}
   424  
   425  	// lookup handshake by receiver
   426  
   427  	lookup := device.indexTable.Lookup(msg.Receiver)
   428  	handshake := lookup.handshake
   429  	if handshake == nil {
   430  		return nil
   431  	}
   432  
   433  	var (
   434  		hash     [blake2s.Size]byte
   435  		chainKey [blake2s.Size]byte
   436  	)
   437  
   438  	ok := func() bool {
   439  		// lock handshake state
   440  
   441  		handshake.mutex.RLock()
   442  		defer handshake.mutex.RUnlock()
   443  
   444  		if handshake.state != handshakeInitiationCreated {
   445  			return false
   446  		}
   447  
   448  		// lock private key for reading
   449  
   450  		device.staticIdentity.RLock()
   451  		defer device.staticIdentity.RUnlock()
   452  
   453  		// finish 3-way DH
   454  
   455  		mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
   456  		mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])
   457  
   458  		func() {
   459  			ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
   460  			mixKey(&chainKey, &chainKey, ss[:])
   461  			setZero(ss[:])
   462  		}()
   463  
   464  		func() {
   465  			ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
   466  			mixKey(&chainKey, &chainKey, ss[:])
   467  			setZero(ss[:])
   468  		}()
   469  
   470  		// add preshared key (psk)
   471  
   472  		var tau [blake2s.Size]byte
   473  		var key [chacha20poly1305.KeySize]byte
   474  		KDF3(
   475  			&chainKey,
   476  			&tau,
   477  			&key,
   478  			chainKey[:],
   479  			handshake.presharedKey[:],
   480  		)
   481  		mixHash(&hash, &hash, tau[:])
   482  
   483  		// authenticate transcript
   484  
   485  		aead, _ := chacha20poly1305.New(key[:])
   486  		_, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
   487  		if err != nil {
   488  			return false
   489  		}
   490  		mixHash(&hash, &hash, msg.Empty[:])
   491  		return true
   492  	}()
   493  
   494  	if !ok {
   495  		return nil
   496  	}
   497  
   498  	// update handshake state
   499  
   500  	handshake.mutex.Lock()
   501  
   502  	handshake.hash = hash
   503  	handshake.chainKey = chainKey
   504  	handshake.remoteIndex = msg.Sender
   505  	handshake.state = handshakeResponseConsumed
   506  
   507  	handshake.mutex.Unlock()
   508  
   509  	setZero(hash[:])
   510  	setZero(chainKey[:])
   511  
   512  	return lookup.peer
   513  }
   514  
// BeginSymmetricSession derives a transport keypair from the completed
// handshake in peer.handshake, wipes the handshake secrets, and installs
// the new keypair in peer.keypairs.
//
// The handshake must be in one of two states:
//   - handshakeResponseConsumed: this side was the initiator;
//   - handshakeResponseCreated: this side was the responder.
//
// Any other state returns an error before anything is modified.
//
// Both sides run KDF2 over the final chaining key with no input. The
// initiator uses the first output as its send key and the responder uses it
// as its receive key, so the two ends agree.
//
// Keypair rotation (under peer.keypairs' lock):
//   - Initiator: the new keypair becomes current immediately.
//       - If a next keypair was pending, it becomes previous and the old
//         current is deleted.
//       - Otherwise the old current becomes previous.
//       - In both cases the old previous is deleted.
//   - Responder:
//       - The new keypair is stored as next, and any pending next is deleted.
//       - previous is cleared and the old previous is deleted.
//       - current is left unchanged; ReceivedWithKeypair later promotes
//         next to current.
func (peer *Peer) BeginSymmetricSession() error {
	device := peer.device
	handshake := &peer.handshake
	handshake.mutex.Lock()
	defer handshake.mutex.Unlock()

	// derive keys

	var isInitiator bool
	var sendKey [chacha20poly1305.KeySize]byte
	var recvKey [chacha20poly1305.KeySize]byte

	if handshake.state == handshakeResponseConsumed {
		KDF2(
			&sendKey,
			&recvKey,
			handshake.chainKey[:],
			nil,
		)
		isInitiator = true
	} else if handshake.state == handshakeResponseCreated {
		// Responder: same KDF2 outputs, swapped roles.
		KDF2(
			&recvKey,
			&sendKey,
			handshake.chainKey[:],
			nil,
		)
		isInitiator = false
	} else {
		return fmt.Errorf("invalid state for keypair derivation: %v", handshake.state)
	}

	// zero handshake secrets; the transport keys no longer depend on them

	setZero(handshake.chainKey[:])
	setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line.
	setZero(handshake.localEphemeral[:])
	// peer.handshake is the same struct as handshake.
	peer.handshake.state = handshakeZeroed

	// create AEAD instances (New only fails on a wrong key length, and both
	// keys are KeySize bytes)

	keypair := new(Keypair)
	keypair.send, _ = chacha20poly1305.New(sendKey[:])
	keypair.receive, _ = chacha20poly1305.New(recvKey[:])

	setZero(sendKey[:])
	setZero(recvKey[:])

	keypair.created = time.Now()
	keypair.replayFilter.Reset()
	keypair.isInitiator = isInitiator
	keypair.localIndex = peer.handshake.localIndex
	keypair.remoteIndex = peer.handshake.remoteIndex

	// remap index: the handshake's local index now refers to the keypair
	// (presumably SwapIndexForKeypair re-points the index table entry —
	// confirm), so the handshake forgets it

	device.indexTable.SwapIndexForKeypair(handshake.localIndex, keypair)
	handshake.localIndex = 0

	// rotate key pairs

	keypairs := &peer.keypairs
	keypairs.Lock()
	defer keypairs.Unlock()

	previous := keypairs.previous
	next := keypairs.loadNext()
	current := keypairs.current

	if isInitiator {
		if next != nil {
			keypairs.storeNext(nil)
			keypairs.previous = next
			device.DeleteKeypair(current)
		} else {
			keypairs.previous = current
		}
		device.DeleteKeypair(previous)
		keypairs.current = keypair
	} else {
		keypairs.storeNext(keypair)
		device.DeleteKeypair(next)
		keypairs.previous = nil
		device.DeleteKeypair(previous)
	}

	return nil
}
   606  
   607  func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
   608  	keypairs := &peer.keypairs
   609  
   610  	if keypairs.loadNext() != receivedKeypair {
   611  		return false
   612  	}
   613  	keypairs.Lock()
   614  	defer keypairs.Unlock()
   615  	if keypairs.loadNext() != receivedKeypair {
   616  		return false
   617  	}
   618  	old := keypairs.previous
   619  	keypairs.previous = keypairs.current
   620  	peer.device.DeleteKeypair(old)
   621  	keypairs.current = keypairs.loadNext()
   622  	keypairs.storeNext(nil)
   623  	return true
   624  }