github.com/cawidtu/notwireguard-go/device@v0.0.0-20230523131112-68e8e5ce9cdf/noise-protocol.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"errors"
    10  	"fmt"
    11  	"sync"
    12  	"time"
    13  
    14  	"golang.org/x/crypto/blake2s"
    15  	"golang.org/x/crypto/chacha20poly1305"
    16  	"golang.org/x/crypto/poly1305"
    17  	"github.com/cawidtu/notwireguard-go/tai64n"
    18  )
    19  
    20  type handshakeState int
    21  
    22  const (
    23  	handshakeZeroed = handshakeState(iota)
    24  	handshakeInitiationCreated
    25  	handshakeInitiationConsumed
    26  	handshakeResponseCreated
    27  	handshakeResponseConsumed
    28  )
    29  
    30  func (hs handshakeState) String() string {
    31  	switch hs {
    32  	case handshakeZeroed:
    33  		return "handshakeZeroed"
    34  	case handshakeInitiationCreated:
    35  		return "handshakeInitiationCreated"
    36  	case handshakeInitiationConsumed:
    37  		return "handshakeInitiationConsumed"
    38  	case handshakeResponseCreated:
    39  		return "handshakeResponseCreated"
    40  	case handshakeResponseConsumed:
    41  		return "handshakeResponseConsumed"
    42  	default:
    43  		return fmt.Sprintf("Handshake(UNKNOWN:%d)", int(hs))
    44  	}
    45  }
    46  
// Protocol identifiers and labels.
const (
	// NoiseConstruction is the Noise protocol name; init hashes it to form
	// InitialChainKey.
	NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"
	// WGIdentifier is mixed into InitialChainKey by init to form InitialHash.
	WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com"
	// WGLabelMAC1 and WGLabelCookie are not used in this file; presumably they
	// label the mac1 / cookie key derivations elsewhere — confirm.
	WGLabelMAC1   = "mac1----"
	WGLabelCookie = "cookie--"
)
    53  
// Values carried in the leading Type field of each message; the Consume*
// functions reject messages whose Type does not match.
const (
	MessageInitiationType  = 1
	MessageResponseType    = 2
	MessageCookieReplyType = 3
	MessageTransportType   = 4
)
    60  
// Wire sizes of the protocol messages, in bytes.
const (
	// MessageInitiationSize is the size of a handshake initiation message:
	// originally 148 bytes, extended by the 32-byte Obfuscator field.
	MessageInitiationSize      = 180
	MessageResponseSize        = 92                                            // size of response message
	MessageCookieReplySize     = 64                                            // size of cookie reply message
	MessageTransportHeaderSize = 16                                            // size of data preceding content in transport message
	MessageTransportSize       = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport
	MessageKeepaliveSize       = MessageTransportSize                          // size of keepalive
	MessageHandshakeSize       = MessageInitiationSize                         // size of largest handshake related message

	// The final two constants were introduced with the obfuscation code.

	// NoiseHashLen is 32, the size of a full BLAKE2s-256 digest (blake2s.Size).
	// NOTE(review): an earlier comment here said "we use blake2s.size128", but
	// blake2s.Size128 is 16 bytes; confirm which length was intended.
	NoiseHashLen = 32
	// NoiseObfuscateLenMax is not used in this file; presumably an upper bound
	// on the obfuscated data length — TODO confirm against its users.
	NoiseObfuscateLenMax = 192
)
    74  
// Byte offsets within a marshalled transport message: the 4-byte Type comes
// first, then Receiver (4 bytes), Counter (8 bytes) and the content, matching
// the field layout of MessageTransport and MessageTransportHeaderSize.
const (
	MessageTransportOffsetReceiver = 4
	MessageTransportOffsetCounter  = 8
	MessageTransportOffsetContent  = 16
)
    80  
/* Type is an 8-bit field followed by 3 NUL bytes. Because the messages are
 * marshalled in little-endian byte order, we can treat these 4 bytes as a
 * single 32-bit unsigned integer (for now).
 */
    86  
// MessageInitiation is the first handshake message, built by
// CreateMessageInitiation and parsed by ConsumeMessageInitiation. Compared
// with the original 148-byte layout it carries an extra Obfuscator field.
// MAC1 and MAC2 are not set by the code in this file.
type MessageInitiation struct {
	Type       uint32
	Sender     uint32                                        // initiator's local index
	Obfuscator [NoisePublicKeySize]byte                      // new for obfuscation; copied from Handshake.obfuscator
	Ephemeral  NoisePublicKey                                // initiator's ephemeral public key, sent in the clear
	Static     [NoisePublicKeySize + poly1305.TagSize]byte   // initiator's static public key, AEAD-sealed
	Timestamp  [tai64n.TimestampSize + poly1305.TagSize]byte // TAI64N timestamp, AEAD-sealed
	MAC1       [blake2s.Size128]byte
	MAC2       [blake2s.Size128]byte
}
    97  
// MessageResponse is the second handshake message, built by
// CreateMessageResponse and parsed by ConsumeMessageResponse.
type MessageResponse struct {
	Type      uint32
	Sender    uint32                 // responder's local index
	Receiver  uint32                 // the initiator's Sender index, echoed back
	Ephemeral NoisePublicKey         // responder's ephemeral public key
	Empty     [poly1305.TagSize]byte // AEAD tag over an empty plaintext; authenticates the transcript
	MAC1      [blake2s.Size128]byte
	MAC2      [blake2s.Size128]byte
}
   107  
// MessageTransport is a data message: a 16-byte header (Type, Receiver,
// Counter) followed by Content. Presumably Content is ciphertext plus a
// poly1305 tag, given MessageTransportSize — confirm with the send path.
type MessageTransport struct {
	Type     uint32
	Receiver uint32
	Counter  uint64
	Content  []byte
}
   114  
// MessageCookieReply carries an AEAD-sealed 16-byte cookie addressed to the
// index in Receiver. Nonce is sized for XChaCha20-Poly1305
// (chacha20poly1305.NonceSizeX).
type MessageCookieReply struct {
	Type     uint32
	Receiver uint32
	Nonce    [chacha20poly1305.NonceSizeX]byte
	Cookie   [blake2s.Size128 + poly1305.TagSize]byte
}
   121  
// Handshake holds one peer's Noise IKpsk2 handshake state: the running
// transcript (hash, chainKey), the keys involved, and the bookkeeping used for
// replay and flood protection. mutex guards these fields; the Create*/Consume*
// functions and BeginSymmetricSession in this file lock it.
type Handshake struct {
	state                     handshakeState
	mutex                     sync.RWMutex
	hash                      [blake2s.Size]byte       // hash value (transcript hash)
	chainKey                  [blake2s.Size]byte       // chain key
	presharedKey              NoisePresharedKey        // psk
	localEphemeral            NoisePrivateKey          // ephemeral secret key
	localIndex                uint32                   // our index in device.indexTable; deleted/replaced on each new handshake message
	remoteIndex               uint32                   // index for sending (peer's Sender value)
	remoteStatic              NoisePublicKey           // long term key
	remoteEphemeral           NoisePublicKey           // ephemeral public key
	precomputedStaticStatic   [NoisePublicKeySize]byte // precomputed shared secret; all zeros makes handshakes fail
	lastTimestamp             tai64n.Timestamp         // newest accepted initiation timestamp; older/equal ones are rejected as replays
	lastInitiationConsumption time.Time                // when an initiation was last accepted; used for flood limiting
	lastSentHandshake         time.Time                // not used in this file
	obfuscator                [NoisePublicKeySize]byte // copied into outgoing MessageInitiation.Obfuscator
}
   139  
// InitialChainKey and InitialHash are the Noise starting state shared by every
// handshake; they are computed in init. ZeroNonce is the all-zero nonce passed
// to every handshake AEAD call. This is sound because each Seal/Open in this
// file uses a key derived (via KDF2/KDF3) immediately beforehand.
var (
	InitialChainKey [blake2s.Size]byte
	InitialHash     [blake2s.Size]byte
	ZeroNonce       [chacha20poly1305.NonceSize]byte
)
   145  
   146  func mixKey(dst, c *[blake2s.Size]byte, data []byte) {
   147  	KDF1(dst, c[:], data)
   148  }
   149  
   150  func mixHash(dst, h *[blake2s.Size]byte, data []byte) {
   151  	hash, _ := blake2s.New256(nil)
   152  	hash.Write(h[:])
   153  	hash.Write(data)
   154  	hash.Sum(dst[:0])
   155  	hash.Reset()
   156  }
   157  
   158  func (h *Handshake) Clear() {
   159  	setZero(h.localEphemeral[:])
   160  	setZero(h.remoteEphemeral[:])
   161  	setZero(h.chainKey[:])
   162  	setZero(h.hash[:])
   163  	h.localIndex = 0
   164  	h.state = handshakeZeroed
   165  }
   166  
   167  func (h *Handshake) mixHash(data []byte) {
   168  	mixHash(&h.hash, &h.hash, data)
   169  }
   170  
   171  func (h *Handshake) mixKey(data []byte) {
   172  	mixKey(&h.chainKey, &h.chainKey, data)
   173  }
   174  
   175  /* Do basic precomputations
   176   */
   177  func init() {
   178  	InitialChainKey = blake2s.Sum256([]byte(NoiseConstruction))
   179  	mixHash(&InitialHash, &InitialChainKey, []byte(WGIdentifier))
   180  }
   181  
   182  func wgNoiseCreateObfuscator(pubkey [NoisePublicKeySize]byte) [NoisePublicKeySize]byte {
   183      const obfsLabel = "obfs----\x00" // the C uses also the terminating 0 byte for the computation of the hash
   184      var obfuscator [NoisePublicKeySize]byte
   185  
   186      var err error
   187      hash, err := blake2s.New256(nil)
   188  
   189      if err != nil {
   190          panic(err)
   191      }
   192  
   193      hash.Write([]byte(obfsLabel))
   194      hash.Write(pubkey[:])
   195      copy(obfuscator[:], hash.Sum(nil))
   196  
   197      return obfuscator
   198  }
   199  
// CreateMessageInitiation starts a new handshake with peer and returns the
// initiation message to send. MAC1/MAC2 are not filled in here.
//
// It resets the peer's transcript to InitialHash/InitialChainKey, generates a
// fresh ephemeral key, AEAD-seals our static public key and a TAI64N
// timestamp, registers a new local index (replacing the previous one), copies
// the peer's obfuscator into the message, and leaves the handshake in
// handshakeInitiationCreated.
//
// It returns an error if ephemeral key generation or index allocation fails,
// or if either Diffie-Hellman input (ephemeral-static, or the precomputed
// static-static secret) is all zeros.
//
// device.staticIdentity (read) and peer.handshake.mutex (write) are held for
// the whole call.
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
	errZeroECDHResult := errors.New("ECDH returned all zeros")

	device.staticIdentity.RLock()
	defer device.staticIdentity.RUnlock()

	handshake := &peer.handshake
	handshake.mutex.Lock()
	defer handshake.mutex.Unlock()

	// create ephemeral key
	var err error
	handshake.hash = InitialHash
	handshake.chainKey = InitialChainKey
	handshake.localEphemeral, err = newPrivateKey()
	if err != nil {
		return nil, err
	}

	// The responder mixes its own static public key at the same point, so both
	// transcripts start identically.
	handshake.mixHash(handshake.remoteStatic[:])

	msg := MessageInitiation{
		Type:      MessageInitiationType,
		Ephemeral: handshake.localEphemeral.publicKey(),
	}

	handshake.mixKey(msg.Ephemeral[:])
	handshake.mixHash(msg.Ephemeral[:])

	// encrypt static key (key derived from our ephemeral x peer's static DH)
	ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
	if isZero(ss[:]) {
		return nil, errZeroECDHResult
	}
	var key [chacha20poly1305.KeySize]byte
	KDF2(
		&handshake.chainKey,
		&key,
		handshake.chainKey[:],
		ss[:],
	)
	aead, _ := chacha20poly1305.New(key[:]) // cannot fail: key has the exact required length
	aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:])
	handshake.mixHash(msg.Static[:])

	// encrypt timestamp (key derived from the precomputed static-static secret)
	if isZero(handshake.precomputedStaticStatic[:]) {
		return nil, errZeroECDHResult
	}
	KDF2(
		&handshake.chainKey,
		&key,
		handshake.chainKey[:],
		handshake.precomputedStaticStatic[:],
	)
	timestamp := tai64n.Now()
	aead, _ = chacha20poly1305.New(key[:])
	aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])

	// assign index: drop the previous handshake's index, register a new one
	device.indexTable.Delete(handshake.localIndex)
	msg.Sender, err = device.indexTable.NewIndexForHandshake(peer, handshake)
	if err != nil {
		return nil, err
	}
	handshake.localIndex = msg.Sender

	handshake.mixHash(msg.Timestamp[:])

	// Put the obfuscation key into the handshake message as well. It is set
	// after the transcript is final, so it is not mixed into handshake.hash or
	// covered by any AEAD tag here.
	msg.Obfuscator = handshake.obfuscator

	handshake.state = handshakeInitiationCreated
	return &msg, nil
}
   274  
// ConsumeMessageInitiation processes an incoming handshake initiation and
// returns the peer it came from. It returns nil in any of these cases:
//   - the Type is wrong;
//   - a Diffie-Hellman result is all zeros;
//   - the static key or timestamp fails to decrypt;
//   - the sender is not a known running peer;
//   - the timestamp is not newer than the last accepted one (replay);
//   - the previous initiation was accepted within HandshakeInitationRate (flood).
//
// Decryption works on local copies of hash/chainKey. The peer's handshake is
// only updated, to handshakeInitiationConsumed, once all checks pass.
// msg.Obfuscator is not read here.
func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
	var (
		hash     [blake2s.Size]byte
		chainKey [blake2s.Size]byte
	)

	if msg.Type != MessageInitiationType {
		return nil
	}

	device.staticIdentity.RLock()
	defer device.staticIdentity.RUnlock()

	// Rebuild the initiator's transcript: our static public key, then its
	// ephemeral key.
	mixHash(&hash, &InitialHash, device.staticIdentity.publicKey[:])
	mixHash(&hash, &hash, msg.Ephemeral[:])
	mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])

	// decrypt static key (key derived from our static x initiator's ephemeral DH)
	var err error
	var peerPK NoisePublicKey
	var key [chacha20poly1305.KeySize]byte
	ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
	if isZero(ss[:]) {
		return nil
	}
	KDF2(&chainKey, &key, chainKey[:], ss[:])
	aead, _ := chacha20poly1305.New(key[:]) // cannot fail: key has the exact required length
	_, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
	if err != nil {
		return nil
	}
	mixHash(&hash, &hash, msg.Static[:])

	// lookup peer by the decrypted static public key

	peer := device.LookupPeer(peerPK)
	if peer == nil || !peer.isRunning.Get() {
		return nil
	}

	handshake := &peer.handshake

	// verify identity: the timestamp only decrypts if the initiator knows the
	// static-static secret

	var timestamp tai64n.Timestamp

	handshake.mutex.RLock()

	if isZero(handshake.precomputedStaticStatic[:]) {
		handshake.mutex.RUnlock()
		return nil
	}
	KDF2(
		&chainKey,
		&key,
		chainKey[:],
		handshake.precomputedStaticStatic[:],
	)
	aead, _ = chacha20poly1305.New(key[:])
	_, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
	if err != nil {
		handshake.mutex.RUnlock()
		return nil
	}
	mixHash(&hash, &hash, msg.Timestamp[:])

	// protect against replay & flood
	// NOTE(review): these checks run under the read lock, which is released
	// before the write lock below is taken, so two concurrent initiations
	// could both pass them. Confirm whether callers serialize this path.

	replay := !timestamp.After(handshake.lastTimestamp)
	flood := time.Since(handshake.lastInitiationConsumption) <= HandshakeInitationRate
	handshake.mutex.RUnlock()
	if replay {
		device.log.Verbosef("%v - ConsumeMessageInitiation: handshake replay @ %v", peer, timestamp)
		return nil
	}
	if flood {
		device.log.Verbosef("%v - ConsumeMessageInitiation: handshake flood", peer)
		return nil
	}

	// update handshake state

	handshake.mutex.Lock()

	handshake.hash = hash
	handshake.chainKey = chainKey
	handshake.remoteIndex = msg.Sender
	handshake.remoteEphemeral = msg.Ephemeral
	// Only ever move the replay/flood markers forward.
	if timestamp.After(handshake.lastTimestamp) {
		handshake.lastTimestamp = timestamp
	}
	now := time.Now()
	if now.After(handshake.lastInitiationConsumption) {
		handshake.lastInitiationConsumption = now
	}
	handshake.state = handshakeInitiationConsumed

	handshake.mutex.Unlock()

	// Wipe the local copies now that they live in the handshake.
	setZero(hash[:])
	setZero(chainKey[:])

	return peer
}
   379  
   380  func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error) {
   381  	handshake := &peer.handshake
   382  	handshake.mutex.Lock()
   383  	defer handshake.mutex.Unlock()
   384  
   385  	if handshake.state != handshakeInitiationConsumed {
   386  		return nil, errors.New("handshake initiation must be consumed first")
   387  	}
   388  
   389  	// assign index
   390  
   391  	var err error
   392  	device.indexTable.Delete(handshake.localIndex)
   393  	handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake)
   394  	if err != nil {
   395  		return nil, err
   396  	}
   397  
   398  	var msg MessageResponse
   399  	msg.Type = MessageResponseType
   400  	msg.Sender = handshake.localIndex
   401  	msg.Receiver = handshake.remoteIndex
   402  
   403  	// create ephemeral key
   404  
   405  	handshake.localEphemeral, err = newPrivateKey()
   406  	if err != nil {
   407  		return nil, err
   408  	}
   409  	msg.Ephemeral = handshake.localEphemeral.publicKey()
   410  	handshake.mixHash(msg.Ephemeral[:])
   411  	handshake.mixKey(msg.Ephemeral[:])
   412  
   413  	func() {
   414  		ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
   415  		handshake.mixKey(ss[:])
   416  		ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
   417  		handshake.mixKey(ss[:])
   418  	}()
   419  
   420  	// add preshared key
   421  
   422  	var tau [blake2s.Size]byte
   423  	var key [chacha20poly1305.KeySize]byte
   424  
   425  	KDF3(
   426  		&handshake.chainKey,
   427  		&tau,
   428  		&key,
   429  		handshake.chainKey[:],
   430  		handshake.presharedKey[:],
   431  	)
   432  
   433  	handshake.mixHash(tau[:])
   434  
   435  	func() {
   436  		aead, _ := chacha20poly1305.New(key[:])
   437  		aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
   438  		handshake.mixHash(msg.Empty[:])
   439  	}()
   440  
   441  	handshake.state = handshakeResponseCreated
   442  
   443  	return &msg, nil
   444  }
   445  
   446  func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
   447  	if msg.Type != MessageResponseType {
   448  		return nil
   449  	}
   450  
   451  	// lookup handshake by receiver
   452  
   453  	lookup := device.indexTable.Lookup(msg.Receiver)
   454  	handshake := lookup.handshake
   455  	if handshake == nil {
   456  		return nil
   457  	}
   458  
   459  	var (
   460  		hash     [blake2s.Size]byte
   461  		chainKey [blake2s.Size]byte
   462  	)
   463  
   464  	ok := func() bool {
   465  		// lock handshake state
   466  
   467  		handshake.mutex.RLock()
   468  		defer handshake.mutex.RUnlock()
   469  
   470  		if handshake.state != handshakeInitiationCreated {
   471  			return false
   472  		}
   473  
   474  		// lock private key for reading
   475  
   476  		device.staticIdentity.RLock()
   477  		defer device.staticIdentity.RUnlock()
   478  
   479  		// finish 3-way DH
   480  
   481  		mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
   482  		mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])
   483  
   484  		func() {
   485  			ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
   486  			mixKey(&chainKey, &chainKey, ss[:])
   487  			setZero(ss[:])
   488  		}()
   489  
   490  		func() {
   491  			ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
   492  			mixKey(&chainKey, &chainKey, ss[:])
   493  			setZero(ss[:])
   494  		}()
   495  
   496  		// add preshared key (psk)
   497  
   498  		var tau [blake2s.Size]byte
   499  		var key [chacha20poly1305.KeySize]byte
   500  		KDF3(
   501  			&chainKey,
   502  			&tau,
   503  			&key,
   504  			chainKey[:],
   505  			handshake.presharedKey[:],
   506  		)
   507  		mixHash(&hash, &hash, tau[:])
   508  
   509  		// authenticate transcript
   510  
   511  		aead, _ := chacha20poly1305.New(key[:])
   512  		_, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
   513  		if err != nil {
   514  			return false
   515  		}
   516  		mixHash(&hash, &hash, msg.Empty[:])
   517  		return true
   518  	}()
   519  
   520  	if !ok {
   521  		return nil
   522  	}
   523  
   524  	// update handshake state
   525  
   526  	handshake.mutex.Lock()
   527  
   528  	handshake.hash = hash
   529  	handshake.chainKey = chainKey
   530  	handshake.remoteIndex = msg.Sender
   531  	handshake.state = handshakeResponseConsumed
   532  
   533  	handshake.mutex.Unlock()
   534  
   535  	setZero(hash[:])
   536  	setZero(chainKey[:])
   537  
   538  	return lookup.peer
   539  }
   540  
   541  /* Derives a new keypair from the current handshake state
   542   *
   543   */
   544  func (peer *Peer) BeginSymmetricSession() error {
   545  	device := peer.device
   546  	handshake := &peer.handshake
   547  	handshake.mutex.Lock()
   548  	defer handshake.mutex.Unlock()
   549  
   550  	// derive keys
   551  
   552  	var isInitiator bool
   553  	var sendKey [chacha20poly1305.KeySize]byte
   554  	var recvKey [chacha20poly1305.KeySize]byte
   555  
   556  	if handshake.state == handshakeResponseConsumed {
   557  		KDF2(
   558  			&sendKey,
   559  			&recvKey,
   560  			handshake.chainKey[:],
   561  			nil,
   562  		)
   563  		isInitiator = true
   564  	} else if handshake.state == handshakeResponseCreated {
   565  		KDF2(
   566  			&recvKey,
   567  			&sendKey,
   568  			handshake.chainKey[:],
   569  			nil,
   570  		)
   571  		isInitiator = false
   572  	} else {
   573  		return fmt.Errorf("invalid state for keypair derivation: %v", handshake.state)
   574  	}
   575  
   576  	// zero handshake
   577  
   578  	setZero(handshake.chainKey[:])
   579  	setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line.
   580  	setZero(handshake.localEphemeral[:])
   581  	peer.handshake.state = handshakeZeroed
   582  
   583  	// create AEAD instances
   584  
   585  	keypair := new(Keypair)
   586  	keypair.send, _ = chacha20poly1305.New(sendKey[:])
   587  	keypair.receive, _ = chacha20poly1305.New(recvKey[:])
   588  
   589  	setZero(sendKey[:])
   590  	setZero(recvKey[:])
   591  
   592  	keypair.created = time.Now()
   593  	keypair.replayFilter.Reset()
   594  	keypair.isInitiator = isInitiator
   595  	keypair.localIndex = peer.handshake.localIndex
   596  	keypair.remoteIndex = peer.handshake.remoteIndex
   597  
   598  	// remap index
   599  
   600  	device.indexTable.SwapIndexForKeypair(handshake.localIndex, keypair)
   601  	handshake.localIndex = 0
   602  
   603  	// rotate key pairs
   604  
   605  	keypairs := &peer.keypairs
   606  	keypairs.Lock()
   607  	defer keypairs.Unlock()
   608  
   609  	previous := keypairs.previous
   610  	next := keypairs.loadNext()
   611  	current := keypairs.current
   612  
   613  	if isInitiator {
   614  		if next != nil {
   615  			keypairs.storeNext(nil)
   616  			keypairs.previous = next
   617  			device.DeleteKeypair(current)
   618  		} else {
   619  			keypairs.previous = current
   620  		}
   621  		device.DeleteKeypair(previous)
   622  		keypairs.current = keypair
   623  	} else {
   624  		keypairs.storeNext(keypair)
   625  		device.DeleteKeypair(next)
   626  		keypairs.previous = nil
   627  		device.DeleteKeypair(previous)
   628  	}
   629  
   630  	return nil
   631  }
   632  
   633  func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
   634  	keypairs := &peer.keypairs
   635  
   636  	if keypairs.loadNext() != receivedKeypair {
   637  		return false
   638  	}
   639  	keypairs.Lock()
   640  	defer keypairs.Unlock()
   641  	if keypairs.loadNext() != receivedKeypair {
   642  		return false
   643  	}
   644  	old := keypairs.previous
   645  	keypairs.previous = keypairs.current
   646  	peer.device.DeleteKeypair(old)
   647  	keypairs.current = keypairs.loadNext()
   648  	keypairs.storeNext(nil)
   649  	return true
   650  }