github.com/bepass-org/wireguard-go@v1.0.4-rc2.0.20240304192354-ebce6572bc24/device/noise-protocol.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"errors"
    10  	"fmt"
    11  	"sync"
    12  	"time"
    13  
    14  	"golang.org/x/crypto/blake2s"
    15  	"golang.org/x/crypto/chacha20poly1305"
    16  	"golang.org/x/crypto/poly1305"
    17  
    18  	"github.com/bepass-org/wireguard-go/tai64n"
    19  )
    20  
    21  type handshakeState int
    22  
    23  const (
    24  	handshakeZeroed = handshakeState(iota)
    25  	handshakeInitiationCreated
    26  	handshakeInitiationConsumed
    27  	handshakeResponseCreated
    28  	handshakeResponseConsumed
    29  )
    30  
    31  func (hs handshakeState) String() string {
    32  	switch hs {
    33  	case handshakeZeroed:
    34  		return "handshakeZeroed"
    35  	case handshakeInitiationCreated:
    36  		return "handshakeInitiationCreated"
    37  	case handshakeInitiationConsumed:
    38  		return "handshakeInitiationConsumed"
    39  	case handshakeResponseCreated:
    40  		return "handshakeResponseCreated"
    41  	case handshakeResponseConsumed:
    42  		return "handshakeResponseConsumed"
    43  	default:
    44  		return fmt.Sprintf("Handshake(UNKNOWN:%d)", int(hs))
    45  	}
    46  }
    47  
    48  const (
    49  	NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"
    50  	WGIdentifier      = "WireGuard v1 zx2c4 Jason@zx2c4.com"
    51  	WGLabelMAC1       = "mac1----"
    52  	WGLabelCookie     = "cookie--"
    53  )
    54  
    55  const (
    56  	MessageInitiationType  = 1
    57  	MessageResponseType    = 2
    58  	MessageCookieReplyType = 3
    59  	MessageTransportType   = 4
    60  )
    61  
    62  const (
    63  	MessageInitiationSize      = 148                                           // size of handshake initiation message
    64  	MessageResponseSize        = 92                                            // size of response message
    65  	MessageCookieReplySize     = 64                                            // size of cookie reply message
    66  	MessageTransportHeaderSize = 16                                            // size of data preceding content in transport message
    67  	MessageTransportSize       = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport
    68  	MessageKeepaliveSize       = MessageTransportSize                          // size of keepalive
    69  	MessageHandshakeSize       = MessageInitiationSize                         // size of largest handshake related message
    70  )
    71  
    72  const (
    73  	MessageTransportOffsetReceiver = 4
    74  	MessageTransportOffsetCounter  = 8
    75  	MessageTransportOffsetContent  = 16
    76  )
    77  
    78  /* Type is an 8-bit field, followed by 3 nul bytes,
    79   * by marshalling the messages in little-endian byteorder
    80   * we can treat these as a 32-bit unsigned int (for now)
    81   *
    82   */
    83  
    84  type MessageInitiation struct {
    85  	Type      uint32
    86  	Sender    uint32
    87  	Ephemeral NoisePublicKey
    88  	Static    [NoisePublicKeySize + poly1305.TagSize]byte
    89  	Timestamp [tai64n.TimestampSize + poly1305.TagSize]byte
    90  	MAC1      [blake2s.Size128]byte
    91  	MAC2      [blake2s.Size128]byte
    92  }
    93  
    94  type MessageResponse struct {
    95  	Type      uint32
    96  	Sender    uint32
    97  	Receiver  uint32
    98  	Ephemeral NoisePublicKey
    99  	Empty     [poly1305.TagSize]byte
   100  	MAC1      [blake2s.Size128]byte
   101  	MAC2      [blake2s.Size128]byte
   102  }
   103  
   104  type MessageTransport struct {
   105  	Type     uint32
   106  	Receiver uint32
   107  	Counter  uint64
   108  	Content  []byte
   109  }
   110  
   111  type MessageCookieReply struct {
   112  	Type     uint32
   113  	Receiver uint32
   114  	Nonce    [chacha20poly1305.NonceSizeX]byte
   115  	Cookie   [blake2s.Size128 + poly1305.TagSize]byte
   116  }
   117  
   118  type Handshake struct {
   119  	state                     handshakeState
   120  	mutex                     sync.RWMutex
   121  	hash                      [blake2s.Size]byte       // hash value
   122  	chainKey                  [blake2s.Size]byte       // chain key
   123  	presharedKey              NoisePresharedKey        // psk
   124  	localEphemeral            NoisePrivateKey          // ephemeral secret key
   125  	localIndex                uint32                   // used to clear hash-table
   126  	remoteIndex               uint32                   // index for sending
   127  	remoteStatic              NoisePublicKey           // long term key
   128  	remoteEphemeral           NoisePublicKey           // ephemeral public key
   129  	precomputedStaticStatic   [NoisePublicKeySize]byte // precomputed shared secret
   130  	lastTimestamp             tai64n.Timestamp
   131  	lastInitiationConsumption time.Time
   132  	lastSentHandshake         time.Time
   133  }
   134  
   135  var (
   136  	InitialChainKey [blake2s.Size]byte
   137  	InitialHash     [blake2s.Size]byte
   138  	ZeroNonce       [chacha20poly1305.NonceSize]byte
   139  )
   140  
   141  func mixKey(dst, c *[blake2s.Size]byte, data []byte) {
   142  	KDF1(dst, c[:], data)
   143  }
   144  
   145  func mixHash(dst, h *[blake2s.Size]byte, data []byte) {
   146  	hash, _ := blake2s.New256(nil)
   147  	hash.Write(h[:])
   148  	hash.Write(data)
   149  	hash.Sum(dst[:0])
   150  	hash.Reset()
   151  }
   152  
   153  func (h *Handshake) Clear() {
   154  	setZero(h.localEphemeral[:])
   155  	setZero(h.remoteEphemeral[:])
   156  	setZero(h.chainKey[:])
   157  	setZero(h.hash[:])
   158  	h.localIndex = 0
   159  	h.state = handshakeZeroed
   160  }
   161  
   162  func (h *Handshake) mixHash(data []byte) {
   163  	mixHash(&h.hash, &h.hash, data)
   164  }
   165  
   166  func (h *Handshake) mixKey(data []byte) {
   167  	mixKey(&h.chainKey, &h.chainKey, data)
   168  }
   169  
   170  /* Do basic precomputations
   171   */
   172  func init() {
   173  	InitialChainKey = blake2s.Sum256([]byte(NoiseConstruction))
   174  	mixHash(&InitialHash, &InitialChainKey, []byte(WGIdentifier))
   175  }
   176  
   177  func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
   178  	device.staticIdentity.RLock()
   179  	defer device.staticIdentity.RUnlock()
   180  
   181  	handshake := &peer.handshake
   182  	handshake.mutex.Lock()
   183  	defer handshake.mutex.Unlock()
   184  
   185  	// create ephemeral key
   186  	var err error
   187  	handshake.hash = InitialHash
   188  	handshake.chainKey = InitialChainKey
   189  	handshake.localEphemeral, err = newPrivateKey()
   190  	if err != nil {
   191  		return nil, err
   192  	}
   193  
   194  	handshake.mixHash(handshake.remoteStatic[:])
   195  
   196  	msg := MessageInitiation{
   197  		Type:      MessageInitiationType,
   198  		Ephemeral: handshake.localEphemeral.publicKey(),
   199  	}
   200  
   201  	handshake.mixKey(msg.Ephemeral[:])
   202  	handshake.mixHash(msg.Ephemeral[:])
   203  
   204  	// encrypt static key
   205  	ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
   206  	if err != nil {
   207  		return nil, err
   208  	}
   209  	var key [chacha20poly1305.KeySize]byte
   210  	KDF2(
   211  		&handshake.chainKey,
   212  		&key,
   213  		handshake.chainKey[:],
   214  		ss[:],
   215  	)
   216  	aead, _ := chacha20poly1305.New(key[:])
   217  	aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:])
   218  	handshake.mixHash(msg.Static[:])
   219  
   220  	// encrypt timestamp
   221  	if isZero(handshake.precomputedStaticStatic[:]) {
   222  		return nil, errInvalidPublicKey
   223  	}
   224  	KDF2(
   225  		&handshake.chainKey,
   226  		&key,
   227  		handshake.chainKey[:],
   228  		handshake.precomputedStaticStatic[:],
   229  	)
   230  	timestamp := tai64n.Now()
   231  	aead, _ = chacha20poly1305.New(key[:])
   232  	aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
   233  
   234  	// assign index
   235  	device.indexTable.Delete(handshake.localIndex)
   236  	msg.Sender, err = device.indexTable.NewIndexForHandshake(peer, handshake)
   237  	if err != nil {
   238  		return nil, err
   239  	}
   240  	handshake.localIndex = msg.Sender
   241  
   242  	handshake.mixHash(msg.Timestamp[:])
   243  	handshake.state = handshakeInitiationCreated
   244  	return &msg, nil
   245  }
   246  
   247  func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
   248  	var (
   249  		hash     [blake2s.Size]byte
   250  		chainKey [blake2s.Size]byte
   251  	)
   252  
   253  	if msg.Type != MessageInitiationType {
   254  		return nil
   255  	}
   256  
   257  	device.staticIdentity.RLock()
   258  	defer device.staticIdentity.RUnlock()
   259  
   260  	mixHash(&hash, &InitialHash, device.staticIdentity.publicKey[:])
   261  	mixHash(&hash, &hash, msg.Ephemeral[:])
   262  	mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
   263  
   264  	// decrypt static key
   265  	var peerPK NoisePublicKey
   266  	var key [chacha20poly1305.KeySize]byte
   267  	ss, err := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
   268  	if err != nil {
   269  		return nil
   270  	}
   271  	KDF2(&chainKey, &key, chainKey[:], ss[:])
   272  	aead, _ := chacha20poly1305.New(key[:])
   273  	_, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
   274  	if err != nil {
   275  		return nil
   276  	}
   277  	mixHash(&hash, &hash, msg.Static[:])
   278  
   279  	// lookup peer
   280  
   281  	peer := device.LookupPeer(peerPK)
   282  	if peer == nil || !peer.isRunning.Load() {
   283  		return nil
   284  	}
   285  
   286  	handshake := &peer.handshake
   287  
   288  	// verify identity
   289  
   290  	var timestamp tai64n.Timestamp
   291  
   292  	handshake.mutex.RLock()
   293  
   294  	if isZero(handshake.precomputedStaticStatic[:]) {
   295  		handshake.mutex.RUnlock()
   296  		return nil
   297  	}
   298  	KDF2(
   299  		&chainKey,
   300  		&key,
   301  		chainKey[:],
   302  		handshake.precomputedStaticStatic[:],
   303  	)
   304  	aead, _ = chacha20poly1305.New(key[:])
   305  	_, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
   306  	if err != nil {
   307  		handshake.mutex.RUnlock()
   308  		return nil
   309  	}
   310  	mixHash(&hash, &hash, msg.Timestamp[:])
   311  
   312  	// protect against replay & flood
   313  
   314  	replay := !timestamp.After(handshake.lastTimestamp)
   315  	flood := time.Since(handshake.lastInitiationConsumption) <= HandshakeInitationRate
   316  	handshake.mutex.RUnlock()
   317  	if replay {
   318  		device.log.Verbosef("%v - ConsumeMessageInitiation: handshake replay @ %v", peer, timestamp)
   319  		return nil
   320  	}
   321  	if flood {
   322  		device.log.Verbosef("%v - ConsumeMessageInitiation: handshake flood", peer)
   323  		return nil
   324  	}
   325  
   326  	// update handshake state
   327  
   328  	handshake.mutex.Lock()
   329  
   330  	handshake.hash = hash
   331  	handshake.chainKey = chainKey
   332  	handshake.remoteIndex = msg.Sender
   333  	handshake.remoteEphemeral = msg.Ephemeral
   334  	if timestamp.After(handshake.lastTimestamp) {
   335  		handshake.lastTimestamp = timestamp
   336  	}
   337  	now := time.Now()
   338  	if now.After(handshake.lastInitiationConsumption) {
   339  		handshake.lastInitiationConsumption = now
   340  	}
   341  	handshake.state = handshakeInitiationConsumed
   342  
   343  	handshake.mutex.Unlock()
   344  
   345  	setZero(hash[:])
   346  	setZero(chainKey[:])
   347  
   348  	return peer
   349  }
   350  
   351  func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error) {
   352  	handshake := &peer.handshake
   353  	handshake.mutex.Lock()
   354  	defer handshake.mutex.Unlock()
   355  
   356  	if handshake.state != handshakeInitiationConsumed {
   357  		return nil, errors.New("handshake initiation must be consumed first")
   358  	}
   359  
   360  	// assign index
   361  
   362  	var err error
   363  	device.indexTable.Delete(handshake.localIndex)
   364  	handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake)
   365  	if err != nil {
   366  		return nil, err
   367  	}
   368  
   369  	var msg MessageResponse
   370  	msg.Type = MessageResponseType
   371  	msg.Sender = handshake.localIndex
   372  	msg.Receiver = handshake.remoteIndex
   373  
   374  	// create ephemeral key
   375  
   376  	handshake.localEphemeral, err = newPrivateKey()
   377  	if err != nil {
   378  		return nil, err
   379  	}
   380  	msg.Ephemeral = handshake.localEphemeral.publicKey()
   381  	handshake.mixHash(msg.Ephemeral[:])
   382  	handshake.mixKey(msg.Ephemeral[:])
   383  
   384  	ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
   385  	if err != nil {
   386  		return nil, err
   387  	}
   388  	handshake.mixKey(ss[:])
   389  	ss, err = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
   390  	if err != nil {
   391  		return nil, err
   392  	}
   393  	handshake.mixKey(ss[:])
   394  
   395  	// add preshared key
   396  
   397  	var tau [blake2s.Size]byte
   398  	var key [chacha20poly1305.KeySize]byte
   399  
   400  	KDF3(
   401  		&handshake.chainKey,
   402  		&tau,
   403  		&key,
   404  		handshake.chainKey[:],
   405  		handshake.presharedKey[:],
   406  	)
   407  
   408  	handshake.mixHash(tau[:])
   409  
   410  	aead, _ := chacha20poly1305.New(key[:])
   411  	aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
   412  	handshake.mixHash(msg.Empty[:])
   413  
   414  	handshake.state = handshakeResponseCreated
   415  
   416  	return &msg, nil
   417  }
   418  
   419  func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
   420  	if msg.Type != MessageResponseType {
   421  		return nil
   422  	}
   423  
   424  	// lookup handshake by receiver
   425  
   426  	lookup := device.indexTable.Lookup(msg.Receiver)
   427  	handshake := lookup.handshake
   428  	if handshake == nil {
   429  		return nil
   430  	}
   431  
   432  	var (
   433  		hash     [blake2s.Size]byte
   434  		chainKey [blake2s.Size]byte
   435  	)
   436  
   437  	ok := func() bool {
   438  		// lock handshake state
   439  
   440  		handshake.mutex.RLock()
   441  		defer handshake.mutex.RUnlock()
   442  
   443  		if handshake.state != handshakeInitiationCreated {
   444  			return false
   445  		}
   446  
   447  		// lock private key for reading
   448  
   449  		device.staticIdentity.RLock()
   450  		defer device.staticIdentity.RUnlock()
   451  
   452  		// finish 3-way DH
   453  
   454  		mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
   455  		mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])
   456  
   457  		ss, err := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
   458  		if err != nil {
   459  			return false
   460  		}
   461  		mixKey(&chainKey, &chainKey, ss[:])
   462  		setZero(ss[:])
   463  
   464  		ss, err = device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
   465  		if err != nil {
   466  			return false
   467  		}
   468  		mixKey(&chainKey, &chainKey, ss[:])
   469  		setZero(ss[:])
   470  
   471  		// add preshared key (psk)
   472  
   473  		var tau [blake2s.Size]byte
   474  		var key [chacha20poly1305.KeySize]byte
   475  		KDF3(
   476  			&chainKey,
   477  			&tau,
   478  			&key,
   479  			chainKey[:],
   480  			handshake.presharedKey[:],
   481  		)
   482  		mixHash(&hash, &hash, tau[:])
   483  
   484  		// authenticate transcript
   485  
   486  		aead, _ := chacha20poly1305.New(key[:])
   487  		_, err = aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
   488  		if err != nil {
   489  			return false
   490  		}
   491  		mixHash(&hash, &hash, msg.Empty[:])
   492  		return true
   493  	}()
   494  
   495  	if !ok {
   496  		return nil
   497  	}
   498  
   499  	// update handshake state
   500  
   501  	handshake.mutex.Lock()
   502  
   503  	handshake.hash = hash
   504  	handshake.chainKey = chainKey
   505  	handshake.remoteIndex = msg.Sender
   506  	handshake.state = handshakeResponseConsumed
   507  
   508  	handshake.mutex.Unlock()
   509  
   510  	setZero(hash[:])
   511  	setZero(chainKey[:])
   512  
   513  	return lookup.peer
   514  }
   515  
   516  /* Derives a new keypair from the current handshake state
   517   *
   518   */
   519  func (peer *Peer) BeginSymmetricSession() error {
   520  	device := peer.device
   521  	handshake := &peer.handshake
   522  	handshake.mutex.Lock()
   523  	defer handshake.mutex.Unlock()
   524  
   525  	// derive keys
   526  
   527  	var isInitiator bool
   528  	var sendKey [chacha20poly1305.KeySize]byte
   529  	var recvKey [chacha20poly1305.KeySize]byte
   530  
   531  	if handshake.state == handshakeResponseConsumed {
   532  		KDF2(
   533  			&sendKey,
   534  			&recvKey,
   535  			handshake.chainKey[:],
   536  			nil,
   537  		)
   538  		isInitiator = true
   539  	} else if handshake.state == handshakeResponseCreated {
   540  		KDF2(
   541  			&recvKey,
   542  			&sendKey,
   543  			handshake.chainKey[:],
   544  			nil,
   545  		)
   546  		isInitiator = false
   547  	} else {
   548  		return fmt.Errorf("invalid state for keypair derivation: %v", handshake.state)
   549  	}
   550  
   551  	// zero handshake
   552  
   553  	setZero(handshake.chainKey[:])
   554  	setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line.
   555  	setZero(handshake.localEphemeral[:])
   556  	peer.handshake.state = handshakeZeroed
   557  
   558  	// create AEAD instances
   559  
   560  	keypair := new(Keypair)
   561  	keypair.send, _ = chacha20poly1305.New(sendKey[:])
   562  	keypair.receive, _ = chacha20poly1305.New(recvKey[:])
   563  
   564  	setZero(sendKey[:])
   565  	setZero(recvKey[:])
   566  
   567  	keypair.created = time.Now()
   568  	keypair.replayFilter.Reset()
   569  	keypair.isInitiator = isInitiator
   570  	keypair.localIndex = peer.handshake.localIndex
   571  	keypair.remoteIndex = peer.handshake.remoteIndex
   572  
   573  	// remap index
   574  
   575  	device.indexTable.SwapIndexForKeypair(handshake.localIndex, keypair)
   576  	handshake.localIndex = 0
   577  
   578  	// rotate key pairs
   579  
   580  	keypairs := &peer.keypairs
   581  	keypairs.Lock()
   582  	defer keypairs.Unlock()
   583  
   584  	previous := keypairs.previous
   585  	next := keypairs.next.Load()
   586  	current := keypairs.current
   587  
   588  	if isInitiator {
   589  		if next != nil {
   590  			keypairs.next.Store(nil)
   591  			keypairs.previous = next
   592  			device.DeleteKeypair(current)
   593  		} else {
   594  			keypairs.previous = current
   595  		}
   596  		device.DeleteKeypair(previous)
   597  		keypairs.current = keypair
   598  	} else {
   599  		keypairs.next.Store(keypair)
   600  		device.DeleteKeypair(next)
   601  		keypairs.previous = nil
   602  		device.DeleteKeypair(previous)
   603  	}
   604  
   605  	return nil
   606  }
   607  
   608  func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
   609  	keypairs := &peer.keypairs
   610  
   611  	if keypairs.next.Load() != receivedKeypair {
   612  		return false
   613  	}
   614  	keypairs.Lock()
   615  	defer keypairs.Unlock()
   616  	if keypairs.next.Load() != receivedKeypair {
   617  		return false
   618  	}
   619  	old := keypairs.previous
   620  	keypairs.previous = keypairs.current
   621  	peer.device.DeleteKeypair(old)
   622  	keypairs.current = keypairs.next.Load()
   623  	keypairs.next.Store(nil)
   624  	return true
   625  }