github.com/amnezia-vpn/amneziawg-go@v0.2.8/device/noise-protocol.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package device 7 8 import ( 9 "errors" 10 "fmt" 11 "sync" 12 "time" 13 14 "golang.org/x/crypto/blake2s" 15 "golang.org/x/crypto/chacha20poly1305" 16 "golang.org/x/crypto/poly1305" 17 18 "github.com/amnezia-vpn/amneziawg-go/tai64n" 19 ) 20 21 type handshakeState int 22 23 const ( 24 handshakeZeroed = handshakeState(iota) 25 handshakeInitiationCreated 26 handshakeInitiationConsumed 27 handshakeResponseCreated 28 handshakeResponseConsumed 29 ) 30 31 func (hs handshakeState) String() string { 32 switch hs { 33 case handshakeZeroed: 34 return "handshakeZeroed" 35 case handshakeInitiationCreated: 36 return "handshakeInitiationCreated" 37 case handshakeInitiationConsumed: 38 return "handshakeInitiationConsumed" 39 case handshakeResponseCreated: 40 return "handshakeResponseCreated" 41 case handshakeResponseConsumed: 42 return "handshakeResponseConsumed" 43 default: 44 return fmt.Sprintf("Handshake(UNKNOWN:%d)", int(hs)) 45 } 46 } 47 48 const ( 49 NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s" 50 WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com" 51 WGLabelMAC1 = "mac1----" 52 WGLabelCookie = "cookie--" 53 ) 54 55 var ( 56 MessageInitiationType uint32 = 1 57 MessageResponseType uint32 = 2 58 MessageCookieReplyType uint32 = 3 59 MessageTransportType uint32 = 4 60 ) 61 62 const ( 63 MessageInitiationSize = 148 // size of handshake initiation message 64 MessageResponseSize = 92 // size of response message 65 MessageCookieReplySize = 64 // size of cookie reply message 66 MessageTransportHeaderSize = 16 // size of data preceding content in transport message 67 MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport 68 MessageKeepaliveSize = MessageTransportSize // size of keepalive 69 MessageHandshakeSize = MessageInitiationSize // size of largest handshake related message 70 ) 71 72 const ( 73 MessageTransportOffsetReceiver = 4 74 MessageTransportOffsetCounter = 8 75 MessageTransportOffsetContent = 16 76 ) 77 78 var packetSizeToMsgType map[int]uint32 79 80 var msgTypeToJunkSize map[uint32]int 81 82 /* Type is an 8-bit field, followed by 3 nul bytes, 83 * by marshalling the messages in little-endian byteorder 84 * we can treat these as a 32-bit unsigned int (for now) 85 * 86 */ 87 88 type MessageInitiation struct { 89 Type uint32 90 Sender uint32 91 Ephemeral NoisePublicKey 92 Static [NoisePublicKeySize + poly1305.TagSize]byte 93 Timestamp [tai64n.TimestampSize + poly1305.TagSize]byte 94 MAC1 [blake2s.Size128]byte 95 MAC2 [blake2s.Size128]byte 96 } 97 98 type MessageResponse struct { 99 Type uint32 100 Sender uint32 101 Receiver uint32 102 Ephemeral NoisePublicKey 103 Empty [poly1305.TagSize]byte 104 MAC1 [blake2s.Size128]byte 105 MAC2 [blake2s.Size128]byte 106 } 107 108 type MessageTransport struct { 109 Type uint32 110 Receiver uint32 111 Counter uint64 112 Content []byte 113 } 114 115 type MessageCookieReply struct { 116 Type uint32 117 Receiver uint32 118 Nonce [chacha20poly1305.NonceSizeX]byte 119 Cookie [blake2s.Size128 + poly1305.TagSize]byte 120 } 121 122 type Handshake struct { 123 state handshakeState 124 mutex sync.RWMutex 125 hash [blake2s.Size]byte // hash value 126 chainKey [blake2s.Size]byte // chain key 127 presharedKey NoisePresharedKey // psk 128 localEphemeral NoisePrivateKey // ephemeral secret key 129 localIndex uint32 // used to clear hash-table 130 remoteIndex uint32 // index for sending 131 remoteStatic NoisePublicKey // long term key 132 remoteEphemeral NoisePublicKey // ephemeral public key 133 precomputedStaticStatic [NoisePublicKeySize]byte // precomputed shared secret 134 lastTimestamp tai64n.Timestamp 135 lastInitiationConsumption time.Time 136 lastSentHandshake time.Time 137 } 138 139 var ( 140 InitialChainKey [blake2s.Size]byte 141 InitialHash [blake2s.Size]byte 142 ZeroNonce [chacha20poly1305.NonceSize]byte 143 ) 144 145 func mixKey(dst, c *[blake2s.Size]byte, data []byte) { 146 KDF1(dst, c[:], data) 147 } 148 149 func mixHash(dst, h *[blake2s.Size]byte, data []byte) { 150 hash, _ := blake2s.New256(nil) 151 hash.Write(h[:]) 152 hash.Write(data) 153 hash.Sum(dst[:0]) 154 hash.Reset() 155 } 156 157 func (h *Handshake) Clear() { 158 setZero(h.localEphemeral[:]) 159 setZero(h.remoteEphemeral[:]) 160 setZero(h.chainKey[:]) 161 setZero(h.hash[:]) 162 h.localIndex = 0 163 h.state = handshakeZeroed 164 } 165 166 func (h *Handshake) mixHash(data []byte) { 167 mixHash(&h.hash, &h.hash, data) 168 } 169 170 func (h *Handshake) mixKey(data []byte) { 171 mixKey(&h.chainKey, &h.chainKey, data) 172 } 173 174 /* Do basic precomputations 175 */ 176 func init() { 177 InitialChainKey = blake2s.Sum256([]byte(NoiseConstruction)) 178 mixHash(&InitialHash, &InitialChainKey, []byte(WGIdentifier)) 179 } 180 181 func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { 182 device.staticIdentity.RLock() 183 defer device.staticIdentity.RUnlock() 184 185 handshake := &peer.handshake 186 handshake.mutex.Lock() 187 defer handshake.mutex.Unlock() 188 189 // create ephemeral key 190 var err error 191 handshake.hash = InitialHash 192 handshake.chainKey = InitialChainKey 193 handshake.localEphemeral, err = newPrivateKey() 194 if err != nil { 195 return nil, err 196 } 197 198 handshake.mixHash(handshake.remoteStatic[:]) 199 200 device.aSecMux.RLock() 201 msg := MessageInitiation{ 202 Type: MessageInitiationType, 203 Ephemeral: handshake.localEphemeral.publicKey(), 204 } 205 device.aSecMux.RUnlock() 206 207 handshake.mixKey(msg.Ephemeral[:]) 208 handshake.mixHash(msg.Ephemeral[:]) 209 210 // encrypt static key 211 ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteStatic) 212 if err != nil { 213 return nil, err 214 } 215 var key [chacha20poly1305.KeySize]byte 216 KDF2( 217 &handshake.chainKey, 218 &key, 219 handshake.chainKey[:], 220 ss[:], 221 ) 222 aead, _ := chacha20poly1305.New(key[:]) 223 aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:]) 224 handshake.mixHash(msg.Static[:]) 225 226 // encrypt timestamp 227 if isZero(handshake.precomputedStaticStatic[:]) { 228 return nil, errInvalidPublicKey 229 } 230 KDF2( 231 &handshake.chainKey, 232 &key, 233 handshake.chainKey[:], 234 handshake.precomputedStaticStatic[:], 235 ) 236 timestamp := tai64n.Now() 237 aead, _ = chacha20poly1305.New(key[:]) 238 aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:]) 239 240 // assign index 241 device.indexTable.Delete(handshake.localIndex) 242 msg.Sender, err = device.indexTable.NewIndexForHandshake(peer, handshake) 243 if err != nil { 244 return nil, err 245 } 246 handshake.localIndex = msg.Sender 247 248 handshake.mixHash(msg.Timestamp[:]) 249 handshake.state = handshakeInitiationCreated 250 return &msg, nil 251 } 252 253 func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { 254 var ( 255 hash [blake2s.Size]byte 256 chainKey [blake2s.Size]byte 257 ) 258 259 device.aSecMux.RLock() 260 if msg.Type != MessageInitiationType { 261 device.aSecMux.RUnlock() 262 return nil 263 } 264 device.aSecMux.RUnlock() 265 266 device.staticIdentity.RLock() 267 defer device.staticIdentity.RUnlock() 268 269 mixHash(&hash, &InitialHash, device.staticIdentity.publicKey[:]) 270 mixHash(&hash, &hash, msg.Ephemeral[:]) 271 mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:]) 272 273 // decrypt static key 274 var peerPK NoisePublicKey 275 var key [chacha20poly1305.KeySize]byte 276 ss, err := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) 277 if err != nil { 278 return nil 279 } 280 KDF2(&chainKey, &key, chainKey[:], ss[:]) 281 aead, _ := chacha20poly1305.New(key[:]) 282 _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:]) 283 if err != nil { 284 return nil 285 } 286 mixHash(&hash, &hash, msg.Static[:]) 287 288 // lookup peer 289 290 peer := device.LookupPeer(peerPK) 291 if peer == nil || !peer.isRunning.Load() { 292 return nil 293 } 294 295 handshake := &peer.handshake 296 297 // verify identity 298 299 var timestamp tai64n.Timestamp 300 301 handshake.mutex.RLock() 302 303 if isZero(handshake.precomputedStaticStatic[:]) { 304 handshake.mutex.RUnlock() 305 return nil 306 } 307 KDF2( 308 &chainKey, 309 &key, 310 chainKey[:], 311 handshake.precomputedStaticStatic[:], 312 ) 313 aead, _ = chacha20poly1305.New(key[:]) 314 _, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:]) 315 if err != nil { 316 handshake.mutex.RUnlock() 317 return nil 318 } 319 mixHash(&hash, &hash, msg.Timestamp[:]) 320 321 // protect against replay & flood 322 323 replay := !timestamp.After(handshake.lastTimestamp) 324 flood := time.Since(handshake.lastInitiationConsumption) <= HandshakeInitationRate 325 handshake.mutex.RUnlock() 326 if replay { 327 device.log.Verbosef("%v - ConsumeMessageInitiation: handshake replay @ %v", peer, timestamp) 328 return nil 329 } 330 if flood { 331 device.log.Verbosef("%v - ConsumeMessageInitiation: handshake flood", peer) 332 return nil 333 } 334 335 // update handshake state 336 337 handshake.mutex.Lock() 338 339 handshake.hash = hash 340 handshake.chainKey = chainKey 341 handshake.remoteIndex = msg.Sender 342 handshake.remoteEphemeral = msg.Ephemeral 343 if timestamp.After(handshake.lastTimestamp) { 344 handshake.lastTimestamp = timestamp 345 } 346 now := time.Now() 347 if now.After(handshake.lastInitiationConsumption) { 348 handshake.lastInitiationConsumption = now 349 } 350 handshake.state = handshakeInitiationConsumed 351 352 handshake.mutex.Unlock() 353 354 setZero(hash[:]) 355 setZero(chainKey[:]) 356 357 return peer 358 } 359 360 func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error) { 361 handshake := &peer.handshake 362 handshake.mutex.Lock() 363 defer handshake.mutex.Unlock() 364 365 if handshake.state != handshakeInitiationConsumed { 366 return nil, errors.New("handshake initiation must be consumed first") 367 } 368 369 // assign index 370 371 var err error 372 device.indexTable.Delete(handshake.localIndex) 373 handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake) 374 if err != nil { 375 return nil, err 376 } 377 378 var msg MessageResponse 379 device.aSecMux.RLock() 380 msg.Type = MessageResponseType 381 device.aSecMux.RUnlock() 382 msg.Sender = handshake.localIndex 383 msg.Receiver = handshake.remoteIndex 384 385 // create ephemeral key 386 387 handshake.localEphemeral, err = newPrivateKey() 388 if err != nil { 389 return nil, err 390 } 391 msg.Ephemeral = handshake.localEphemeral.publicKey() 392 handshake.mixHash(msg.Ephemeral[:]) 393 handshake.mixKey(msg.Ephemeral[:]) 394 395 ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral) 396 if err != nil { 397 return nil, err 398 } 399 handshake.mixKey(ss[:]) 400 ss, err = handshake.localEphemeral.sharedSecret(handshake.remoteStatic) 401 if err != nil { 402 return nil, err 403 } 404 handshake.mixKey(ss[:]) 405 406 // add preshared key 407 408 var tau [blake2s.Size]byte 409 var key [chacha20poly1305.KeySize]byte 410 411 KDF3( 412 &handshake.chainKey, 413 &tau, 414 &key, 415 handshake.chainKey[:], 416 handshake.presharedKey[:], 417 ) 418 419 handshake.mixHash(tau[:]) 420 421 aead, _ := chacha20poly1305.New(key[:]) 422 aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:]) 423 handshake.mixHash(msg.Empty[:]) 424 425 handshake.state = handshakeResponseCreated 426 427 return &msg, nil 428 } 429 430 func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { 431 device.aSecMux.RLock() 432 if msg.Type != MessageResponseType { 433 device.aSecMux.RUnlock() 434 return nil 435 } 436 device.aSecMux.RUnlock() 437 438 // lookup handshake by receiver 439 440 lookup := device.indexTable.Lookup(msg.Receiver) 441 handshake := lookup.handshake 442 if handshake == nil { 443 return nil 444 } 445 446 var ( 447 hash [blake2s.Size]byte 448 chainKey [blake2s.Size]byte 449 ) 450 451 ok := func() bool { 452 // lock handshake state 453 454 handshake.mutex.RLock() 455 defer handshake.mutex.RUnlock() 456 457 if handshake.state != handshakeInitiationCreated { 458 return false 459 } 460 461 // lock private key for reading 462 463 device.staticIdentity.RLock() 464 defer device.staticIdentity.RUnlock() 465 466 // finish 3-way DH 467 468 mixHash(&hash, &handshake.hash, msg.Ephemeral[:]) 469 mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:]) 470 471 ss, err := handshake.localEphemeral.sharedSecret(msg.Ephemeral) 472 if err != nil { 473 return false 474 } 475 mixKey(&chainKey, &chainKey, ss[:]) 476 setZero(ss[:]) 477 478 ss, err = device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) 479 if err != nil { 480 return false 481 } 482 mixKey(&chainKey, &chainKey, ss[:]) 483 setZero(ss[:]) 484 485 // add preshared key (psk) 486 487 var tau [blake2s.Size]byte 488 var key [chacha20poly1305.KeySize]byte 489 KDF3( 490 &chainKey, 491 &tau, 492 &key, 493 chainKey[:], 494 handshake.presharedKey[:], 495 ) 496 mixHash(&hash, &hash, tau[:]) 497 498 // authenticate transcript 499 500 aead, _ := chacha20poly1305.New(key[:]) 501 _, err = aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:]) 502 if err != nil { 503 return false 504 } 505 mixHash(&hash, &hash, msg.Empty[:]) 506 return true 507 }() 508 509 if !ok { 510 return nil 511 } 512 513 // update handshake state 514 515 handshake.mutex.Lock() 516 517 handshake.hash = hash 518 handshake.chainKey = chainKey 519 handshake.remoteIndex = msg.Sender 520 handshake.state = handshakeResponseConsumed 521 522 handshake.mutex.Unlock() 523 524 setZero(hash[:]) 525 setZero(chainKey[:]) 526 527 return lookup.peer 528 } 529 530 /* Derives a new keypair from the current handshake state 531 * 532 */ 533 func (peer *Peer) BeginSymmetricSession() error { 534 device := peer.device 535 handshake := &peer.handshake 536 handshake.mutex.Lock() 537 defer handshake.mutex.Unlock() 538 539 // derive keys 540 541 var isInitiator bool 542 var sendKey [chacha20poly1305.KeySize]byte 543 var recvKey [chacha20poly1305.KeySize]byte 544 545 if handshake.state == handshakeResponseConsumed { 546 KDF2( 547 &sendKey, 548 &recvKey, 549 handshake.chainKey[:], 550 nil, 551 ) 552 isInitiator = true 553 } else if handshake.state == handshakeResponseCreated { 554 KDF2( 555 &recvKey, 556 &sendKey, 557 handshake.chainKey[:], 558 nil, 559 ) 560 isInitiator = false 561 } else { 562 return fmt.Errorf("invalid state for keypair derivation: %v", handshake.state) 563 } 564 565 // zero handshake 566 567 setZero(handshake.chainKey[:]) 568 setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line. 569 setZero(handshake.localEphemeral[:]) 570 peer.handshake.state = handshakeZeroed 571 572 // create AEAD instances 573 574 keypair := new(Keypair) 575 keypair.send, _ = chacha20poly1305.New(sendKey[:]) 576 keypair.receive, _ = chacha20poly1305.New(recvKey[:]) 577 578 setZero(sendKey[:]) 579 setZero(recvKey[:]) 580 581 keypair.created = time.Now() 582 keypair.replayFilter.Reset() 583 keypair.isInitiator = isInitiator 584 keypair.localIndex = peer.handshake.localIndex 585 keypair.remoteIndex = peer.handshake.remoteIndex 586 587 // remap index 588 589 device.indexTable.SwapIndexForKeypair(handshake.localIndex, keypair) 590 handshake.localIndex = 0 591 592 // rotate key pairs 593 594 keypairs := &peer.keypairs 595 keypairs.Lock() 596 defer keypairs.Unlock() 597 598 previous := keypairs.previous 599 next := keypairs.next.Load() 600 current := keypairs.current 601 602 if isInitiator { 603 if next != nil { 604 keypairs.next.Store(nil) 605 keypairs.previous = next 606 device.DeleteKeypair(current) 607 } else { 608 keypairs.previous = current 609 } 610 device.DeleteKeypair(previous) 611 keypairs.current = keypair 612 } else { 613 keypairs.next.Store(keypair) 614 device.DeleteKeypair(next) 615 keypairs.previous = nil 616 device.DeleteKeypair(previous) 617 } 618 619 return nil 620 } 621 622 func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool { 623 keypairs := &peer.keypairs 624 625 if keypairs.next.Load() != receivedKeypair { 626 return false 627 } 628 keypairs.Lock() 629 defer keypairs.Unlock() 630 if keypairs.next.Load() != receivedKeypair { 631 return false 632 } 633 old := keypairs.previous 634 keypairs.previous = keypairs.current 635 peer.device.DeleteKeypair(old) 636 keypairs.current = keypairs.next.Load() 637 keypairs.next.Store(nil) 638 return true 639 }