github.com/cawidtu/notwireguard-go/device@v0.0.0-20230523131112-68e8e5ce9cdf/noise-protocol.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package device 7 8 import ( 9 "errors" 10 "fmt" 11 "sync" 12 "time" 13 14 "golang.org/x/crypto/blake2s" 15 "golang.org/x/crypto/chacha20poly1305" 16 "golang.org/x/crypto/poly1305" 17 "github.com/cawidtu/notwireguard-go/tai64n" 18 ) 19 20 type handshakeState int 21 22 const ( 23 handshakeZeroed = handshakeState(iota) 24 handshakeInitiationCreated 25 handshakeInitiationConsumed 26 handshakeResponseCreated 27 handshakeResponseConsumed 28 ) 29 30 func (hs handshakeState) String() string { 31 switch hs { 32 case handshakeZeroed: 33 return "handshakeZeroed" 34 case handshakeInitiationCreated: 35 return "handshakeInitiationCreated" 36 case handshakeInitiationConsumed: 37 return "handshakeInitiationConsumed" 38 case handshakeResponseCreated: 39 return "handshakeResponseCreated" 40 case handshakeResponseConsumed: 41 return "handshakeResponseConsumed" 42 default: 43 return fmt.Sprintf("Handshake(UNKNOWN:%d)", int(hs)) 44 } 45 } 46 47 const ( 48 NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s" 49 WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com" 50 WGLabelMAC1 = "mac1----" 51 WGLabelCookie = "cookie--" 52 ) 53 54 const ( 55 MessageInitiationType = 1 56 MessageResponseType = 2 57 MessageCookieReplyType = 3 58 MessageTransportType = 4 59 ) 60 61 const ( 62 MessageInitiationSize = 180// was 148, prolonged by 32 bytes obfuscionkey 63 // size of handshake initiation message 64 MessageResponseSize = 92 // size of response message 65 MessageCookieReplySize = 64 // size of cookie reply message 66 MessageTransportHeaderSize = 16 // size of data preceding content in transport message 67 MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport 68 MessageKeepaliveSize = MessageTransportSize // size of keepalive 69 MessageHandshakeSize = MessageInitiationSize // size of largest handshake related message 70 NoiseHashLen = 32 // we use blake2s.size128 71 NoiseObfuscateLenMax = 192 72 // the final two fields were introduced with the obfuscation code 73 ) 74 75 const ( 76 MessageTransportOffsetReceiver = 4 77 MessageTransportOffsetCounter = 8 78 MessageTransportOffsetContent = 16 79 ) 80 81 /* Type is an 8-bit field, followed by 3 nul bytes, 82 * by marshalling the messages in little-endian byteorder 83 * we can treat these as a 32-bit unsigned int (for now) 84 * 85 */ 86 87 type MessageInitiation struct { 88 Type uint32 89 Sender uint32 90 Obfuscator [NoisePublicKeySize]byte // new for obfuscation 91 Ephemeral NoisePublicKey 92 Static [NoisePublicKeySize + poly1305.TagSize]byte 93 Timestamp [tai64n.TimestampSize + poly1305.TagSize]byte 94 MAC1 [blake2s.Size128]byte 95 MAC2 [blake2s.Size128]byte 96 } 97 98 type MessageResponse struct { 99 Type uint32 100 Sender uint32 101 Receiver uint32 102 Ephemeral NoisePublicKey 103 Empty [poly1305.TagSize]byte 104 MAC1 [blake2s.Size128]byte 105 MAC2 [blake2s.Size128]byte 106 } 107 108 type MessageTransport struct { 109 Type uint32 110 Receiver uint32 111 Counter uint64 112 Content []byte 113 } 114 115 type MessageCookieReply struct { 116 Type uint32 117 Receiver uint32 118 Nonce [chacha20poly1305.NonceSizeX]byte 119 Cookie [blake2s.Size128 + poly1305.TagSize]byte 120 } 121 122 type Handshake struct { 123 state handshakeState 124 mutex sync.RWMutex 125 hash [blake2s.Size]byte // hash value 126 chainKey [blake2s.Size]byte // chain key 127 presharedKey NoisePresharedKey // psk 128 localEphemeral NoisePrivateKey // ephemeral secret key 129 localIndex uint32 // used to clear hash-table 130 remoteIndex uint32 // index for sending 131 remoteStatic NoisePublicKey // long term key 132 remoteEphemeral NoisePublicKey // ephemeral public key 133 precomputedStaticStatic [NoisePublicKeySize]byte // precomputed shared secret 134 lastTimestamp tai64n.Timestamp 135 lastInitiationConsumption time.Time 136 lastSentHandshake time.Time 137 obfuscator [NoisePublicKeySize]byte 138 } 139 140 var ( 141 InitialChainKey [blake2s.Size]byte 142 InitialHash [blake2s.Size]byte 143 ZeroNonce [chacha20poly1305.NonceSize]byte 144 ) 145 146 func mixKey(dst, c *[blake2s.Size]byte, data []byte) { 147 KDF1(dst, c[:], data) 148 } 149 150 func mixHash(dst, h *[blake2s.Size]byte, data []byte) { 151 hash, _ := blake2s.New256(nil) 152 hash.Write(h[:]) 153 hash.Write(data) 154 hash.Sum(dst[:0]) 155 hash.Reset() 156 } 157 158 func (h *Handshake) Clear() { 159 setZero(h.localEphemeral[:]) 160 setZero(h.remoteEphemeral[:]) 161 setZero(h.chainKey[:]) 162 setZero(h.hash[:]) 163 h.localIndex = 0 164 h.state = handshakeZeroed 165 } 166 167 func (h *Handshake) mixHash(data []byte) { 168 mixHash(&h.hash, &h.hash, data) 169 } 170 171 func (h *Handshake) mixKey(data []byte) { 172 mixKey(&h.chainKey, &h.chainKey, data) 173 } 174 175 /* Do basic precomputations 176 */ 177 func init() { 178 InitialChainKey = blake2s.Sum256([]byte(NoiseConstruction)) 179 mixHash(&InitialHash, &InitialChainKey, []byte(WGIdentifier)) 180 } 181 182 func wgNoiseCreateObfuscator(pubkey [NoisePublicKeySize]byte) [NoisePublicKeySize]byte { 183 const obfsLabel = "obfs----\x00" // the C uses also the terminating 0 byte for the computation of the hash 184 var obfuscator [NoisePublicKeySize]byte 185 186 var err error 187 hash, err := blake2s.New256(nil) 188 189 if err != nil { 190 panic(err) 191 } 192 193 hash.Write([]byte(obfsLabel)) 194 hash.Write(pubkey[:]) 195 copy(obfuscator[:], hash.Sum(nil)) 196 197 return obfuscator 198 } 199 200 func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { 201 errZeroECDHResult := errors.New("ECDH returned all zeros") 202 203 device.staticIdentity.RLock() 204 defer device.staticIdentity.RUnlock() 205 206 handshake := &peer.handshake 207 handshake.mutex.Lock() 208 defer handshake.mutex.Unlock() 209 210 // create ephemeral key 211 var err error 212 handshake.hash = InitialHash 213 handshake.chainKey = InitialChainKey 214 handshake.localEphemeral, err = newPrivateKey() 215 if err != nil { 216 return nil, err 217 } 218 219 handshake.mixHash(handshake.remoteStatic[:]) 220 221 msg := MessageInitiation{ 222 Type: MessageInitiationType, 223 Ephemeral: handshake.localEphemeral.publicKey(), 224 } 225 226 handshake.mixKey(msg.Ephemeral[:]) 227 handshake.mixHash(msg.Ephemeral[:]) 228 229 // encrypt static key 230 ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic) 231 if isZero(ss[:]) { 232 return nil, errZeroECDHResult 233 } 234 var key [chacha20poly1305.KeySize]byte 235 KDF2( 236 &handshake.chainKey, 237 &key, 238 handshake.chainKey[:], 239 ss[:], 240 ) 241 aead, _ := chacha20poly1305.New(key[:]) 242 aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:]) 243 handshake.mixHash(msg.Static[:]) 244 245 // encrypt timestamp 246 if isZero(handshake.precomputedStaticStatic[:]) { 247 return nil, errZeroECDHResult 248 } 249 KDF2( 250 &handshake.chainKey, 251 &key, 252 handshake.chainKey[:], 253 handshake.precomputedStaticStatic[:], 254 ) 255 timestamp := tai64n.Now() 256 aead, _ = chacha20poly1305.New(key[:]) 257 aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:]) 258 259 // assign index 260 device.indexTable.Delete(handshake.localIndex) 261 msg.Sender, err = device.indexTable.NewIndexForHandshake(peer, handshake) 262 if err != nil { 263 return nil, err 264 } 265 handshake.localIndex = msg.Sender 266 267 handshake.mixHash(msg.Timestamp[:]) 268 269 msg.Obfuscator = handshake.obfuscator // put obfuscation key also in handshake messsage 270 271 handshake.state = handshakeInitiationCreated 272 return &msg, nil 273 } 274 275 func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { 276 var ( 277 hash [blake2s.Size]byte 278 chainKey [blake2s.Size]byte 279 ) 280 281 if msg.Type != MessageInitiationType { 282 return nil 283 } 284 285 device.staticIdentity.RLock() 286 defer device.staticIdentity.RUnlock() 287 288 mixHash(&hash, &InitialHash, device.staticIdentity.publicKey[:]) 289 mixHash(&hash, &hash, msg.Ephemeral[:]) 290 mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:]) 291 292 // decrypt static key 293 var err error 294 var peerPK NoisePublicKey 295 var key [chacha20poly1305.KeySize]byte 296 ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) 297 if isZero(ss[:]) { 298 return nil 299 } 300 KDF2(&chainKey, &key, chainKey[:], ss[:]) 301 aead, _ := chacha20poly1305.New(key[:]) 302 _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:]) 303 if err != nil { 304 return nil 305 } 306 mixHash(&hash, &hash, msg.Static[:]) 307 308 // lookup peer 309 310 peer := device.LookupPeer(peerPK) 311 if peer == nil || !peer.isRunning.Get() { 312 return nil 313 } 314 315 handshake := &peer.handshake 316 317 // verify identity 318 319 var timestamp tai64n.Timestamp 320 321 handshake.mutex.RLock() 322 323 if isZero(handshake.precomputedStaticStatic[:]) { 324 handshake.mutex.RUnlock() 325 return nil 326 } 327 KDF2( 328 &chainKey, 329 &key, 330 chainKey[:], 331 handshake.precomputedStaticStatic[:], 332 ) 333 aead, _ = chacha20poly1305.New(key[:]) 334 _, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:]) 335 if err != nil { 336 handshake.mutex.RUnlock() 337 return nil 338 } 339 mixHash(&hash, &hash, msg.Timestamp[:]) 340 341 // protect against replay & flood 342 343 replay := !timestamp.After(handshake.lastTimestamp) 344 flood := time.Since(handshake.lastInitiationConsumption) <= HandshakeInitationRate 345 handshake.mutex.RUnlock() 346 if replay { 347 device.log.Verbosef("%v - ConsumeMessageInitiation: handshake replay @ %v", peer, timestamp) 348 return nil 349 } 350 if flood { 351 device.log.Verbosef("%v - ConsumeMessageInitiation: handshake flood", peer) 352 return nil 353 } 354 355 // update handshake state 356 357 handshake.mutex.Lock() 358 359 handshake.hash = hash 360 handshake.chainKey = chainKey 361 handshake.remoteIndex = msg.Sender 362 handshake.remoteEphemeral = msg.Ephemeral 363 if timestamp.After(handshake.lastTimestamp) { 364 handshake.lastTimestamp = timestamp 365 } 366 now := time.Now() 367 if now.After(handshake.lastInitiationConsumption) { 368 handshake.lastInitiationConsumption = now 369 } 370 handshake.state = handshakeInitiationConsumed 371 372 handshake.mutex.Unlock() 373 374 setZero(hash[:]) 375 setZero(chainKey[:]) 376 377 return peer 378 } 379 380 func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error) { 381 handshake := &peer.handshake 382 handshake.mutex.Lock() 383 defer handshake.mutex.Unlock() 384 385 if handshake.state != handshakeInitiationConsumed { 386 return nil, errors.New("handshake initiation must be consumed first") 387 } 388 389 // assign index 390 391 var err error 392 device.indexTable.Delete(handshake.localIndex) 393 handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake) 394 if err != nil { 395 return nil, err 396 } 397 398 var msg MessageResponse 399 msg.Type = MessageResponseType 400 msg.Sender = handshake.localIndex 401 msg.Receiver = handshake.remoteIndex 402 403 // create ephemeral key 404 405 handshake.localEphemeral, err = newPrivateKey() 406 if err != nil { 407 return nil, err 408 } 409 msg.Ephemeral = handshake.localEphemeral.publicKey() 410 handshake.mixHash(msg.Ephemeral[:]) 411 handshake.mixKey(msg.Ephemeral[:]) 412 413 func() { 414 ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral) 415 handshake.mixKey(ss[:]) 416 ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic) 417 handshake.mixKey(ss[:]) 418 }() 419 420 // add preshared key 421 422 var tau [blake2s.Size]byte 423 var key [chacha20poly1305.KeySize]byte 424 425 KDF3( 426 &handshake.chainKey, 427 &tau, 428 &key, 429 handshake.chainKey[:], 430 handshake.presharedKey[:], 431 ) 432 433 handshake.mixHash(tau[:]) 434 435 func() { 436 aead, _ := chacha20poly1305.New(key[:]) 437 aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:]) 438 handshake.mixHash(msg.Empty[:]) 439 }() 440 441 handshake.state = handshakeResponseCreated 442 443 return &msg, nil 444 } 445 446 func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { 447 if msg.Type != MessageResponseType { 448 return nil 449 } 450 451 // lookup handshake by receiver 452 453 lookup := device.indexTable.Lookup(msg.Receiver) 454 handshake := lookup.handshake 455 if handshake == nil { 456 return nil 457 } 458 459 var ( 460 hash [blake2s.Size]byte 461 chainKey [blake2s.Size]byte 462 ) 463 464 ok := func() bool { 465 // lock handshake state 466 467 handshake.mutex.RLock() 468 defer handshake.mutex.RUnlock() 469 470 if handshake.state != handshakeInitiationCreated { 471 return false 472 } 473 474 // lock private key for reading 475 476 device.staticIdentity.RLock() 477 defer device.staticIdentity.RUnlock() 478 479 // finish 3-way DH 480 481 mixHash(&hash, &handshake.hash, msg.Ephemeral[:]) 482 mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:]) 483 484 func() { 485 ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral) 486 mixKey(&chainKey, &chainKey, ss[:]) 487 setZero(ss[:]) 488 }() 489 490 func() { 491 ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) 492 mixKey(&chainKey, &chainKey, ss[:]) 493 setZero(ss[:]) 494 }() 495 496 // add preshared key (psk) 497 498 var tau [blake2s.Size]byte 499 var key [chacha20poly1305.KeySize]byte 500 KDF3( 501 &chainKey, 502 &tau, 503 &key, 504 chainKey[:], 505 handshake.presharedKey[:], 506 ) 507 mixHash(&hash, &hash, tau[:]) 508 509 // authenticate transcript 510 511 aead, _ := chacha20poly1305.New(key[:]) 512 _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:]) 513 if err != nil { 514 return false 515 } 516 mixHash(&hash, &hash, msg.Empty[:]) 517 return true 518 }() 519 520 if !ok { 521 return nil 522 } 523 524 // update handshake state 525 526 handshake.mutex.Lock() 527 528 handshake.hash = hash 529 handshake.chainKey = chainKey 530 handshake.remoteIndex = msg.Sender 531 handshake.state = handshakeResponseConsumed 532 533 handshake.mutex.Unlock() 534 535 setZero(hash[:]) 536 setZero(chainKey[:]) 537 538 return lookup.peer 539 } 540 541 /* Derives a new keypair from the current handshake state 542 * 543 */ 544 func (peer *Peer) BeginSymmetricSession() error { 545 device := peer.device 546 handshake := &peer.handshake 547 handshake.mutex.Lock() 548 defer handshake.mutex.Unlock() 549 550 // derive keys 551 552 var isInitiator bool 553 var sendKey [chacha20poly1305.KeySize]byte 554 var recvKey [chacha20poly1305.KeySize]byte 555 556 if handshake.state == handshakeResponseConsumed { 557 KDF2( 558 &sendKey, 559 &recvKey, 560 handshake.chainKey[:], 561 nil, 562 ) 563 isInitiator = true 564 } else if handshake.state == handshakeResponseCreated { 565 KDF2( 566 &recvKey, 567 &sendKey, 568 handshake.chainKey[:], 569 nil, 570 ) 571 isInitiator = false 572 } else { 573 return fmt.Errorf("invalid state for keypair derivation: %v", handshake.state) 574 } 575 576 // zero handshake 577 578 setZero(handshake.chainKey[:]) 579 setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line. 580 setZero(handshake.localEphemeral[:]) 581 peer.handshake.state = handshakeZeroed 582 583 // create AEAD instances 584 585 keypair := new(Keypair) 586 keypair.send, _ = chacha20poly1305.New(sendKey[:]) 587 keypair.receive, _ = chacha20poly1305.New(recvKey[:]) 588 589 setZero(sendKey[:]) 590 setZero(recvKey[:]) 591 592 keypair.created = time.Now() 593 keypair.replayFilter.Reset() 594 keypair.isInitiator = isInitiator 595 keypair.localIndex = peer.handshake.localIndex 596 keypair.remoteIndex = peer.handshake.remoteIndex 597 598 // remap index 599 600 device.indexTable.SwapIndexForKeypair(handshake.localIndex, keypair) 601 handshake.localIndex = 0 602 603 // rotate key pairs 604 605 keypairs := &peer.keypairs 606 keypairs.Lock() 607 defer keypairs.Unlock() 608 609 previous := keypairs.previous 610 next := keypairs.loadNext() 611 current := keypairs.current 612 613 if isInitiator { 614 if next != nil { 615 keypairs.storeNext(nil) 616 keypairs.previous = next 617 device.DeleteKeypair(current) 618 } else { 619 keypairs.previous = current 620 } 621 device.DeleteKeypair(previous) 622 keypairs.current = keypair 623 } else { 624 keypairs.storeNext(keypair) 625 device.DeleteKeypair(next) 626 keypairs.previous = nil 627 device.DeleteKeypair(previous) 628 } 629 630 return nil 631 } 632 633 func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool { 634 keypairs := &peer.keypairs 635 636 if keypairs.loadNext() != receivedKeypair { 637 return false 638 } 639 keypairs.Lock() 640 defer keypairs.Unlock() 641 if keypairs.loadNext() != receivedKeypair { 642 return false 643 } 644 old := keypairs.previous 645 keypairs.previous = keypairs.current 646 peer.device.DeleteKeypair(old) 647 keypairs.current = keypairs.loadNext() 648 keypairs.storeNext(nil) 649 return true 650 }