github.com/sagernet/wireguard-go@v0.0.0-20231215174105-89dec3b2f3e8/device/noise-protocol.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package device 7 8 import ( 9 "errors" 10 "fmt" 11 "sync" 12 "time" 13 14 "github.com/sagernet/wireguard-go/tai64n" 15 "golang.org/x/crypto/blake2s" 16 "golang.org/x/crypto/chacha20poly1305" 17 "golang.org/x/crypto/poly1305" 18 ) 19 20 type handshakeState int 21 22 const ( 23 handshakeZeroed = handshakeState(iota) 24 handshakeInitiationCreated 25 handshakeInitiationConsumed 26 handshakeResponseCreated 27 handshakeResponseConsumed 28 ) 29 30 func (hs handshakeState) String() string { 31 switch hs { 32 case handshakeZeroed: 33 return "handshakeZeroed" 34 case handshakeInitiationCreated: 35 return "handshakeInitiationCreated" 36 case handshakeInitiationConsumed: 37 return "handshakeInitiationConsumed" 38 case handshakeResponseCreated: 39 return "handshakeResponseCreated" 40 case handshakeResponseConsumed: 41 return "handshakeResponseConsumed" 42 default: 43 return fmt.Sprintf("Handshake(UNKNOWN:%d)", int(hs)) 44 } 45 } 46 47 const ( 48 NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s" 49 WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com" 50 WGLabelMAC1 = "mac1----" 51 WGLabelCookie = "cookie--" 52 ) 53 54 const ( 55 MessageInitiationType = 1 56 MessageResponseType = 2 57 MessageCookieReplyType = 3 58 MessageTransportType = 4 59 ) 60 61 const ( 62 MessageInitiationSize = 148 // size of handshake initiation message 63 MessageResponseSize = 92 // size of response message 64 MessageCookieReplySize = 64 // size of cookie reply message 65 MessageTransportHeaderSize = 16 // size of data preceding content in transport message 66 MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport 67 MessageKeepaliveSize = MessageTransportSize // size of keepalive 68 MessageHandshakeSize = MessageInitiationSize // size of largest handshake related message 69 ) 70 71 const ( 72 MessageTransportOffsetReceiver = 4 73 MessageTransportOffsetCounter = 8 74 MessageTransportOffsetContent = 16 75 ) 76 77 /* Type is an 8-bit field, followed by 3 nul bytes, 78 * by marshalling the messages in little-endian byteorder 79 * we can treat these as a 32-bit unsigned int (for now) 80 * 81 */ 82 83 type MessageInitiation struct { 84 Type uint32 85 Sender uint32 86 Ephemeral NoisePublicKey 87 Static [NoisePublicKeySize + poly1305.TagSize]byte 88 Timestamp [tai64n.TimestampSize + poly1305.TagSize]byte 89 MAC1 [blake2s.Size128]byte 90 MAC2 [blake2s.Size128]byte 91 } 92 93 type MessageResponse struct { 94 Type uint32 95 Sender uint32 96 Receiver uint32 97 Ephemeral NoisePublicKey 98 Empty [poly1305.TagSize]byte 99 MAC1 [blake2s.Size128]byte 100 MAC2 [blake2s.Size128]byte 101 } 102 103 type MessageTransport struct { 104 Type uint32 105 Receiver uint32 106 Counter uint64 107 Content []byte 108 } 109 110 type MessageCookieReply struct { 111 Type uint32 112 Receiver uint32 113 Nonce [chacha20poly1305.NonceSizeX]byte 114 Cookie [blake2s.Size128 + poly1305.TagSize]byte 115 } 116 117 type Handshake struct { 118 state handshakeState 119 mutex sync.RWMutex 120 hash [blake2s.Size]byte // hash value 121 chainKey [blake2s.Size]byte // chain key 122 presharedKey NoisePresharedKey // psk 123 localEphemeral NoisePrivateKey // ephemeral secret key 124 localIndex uint32 // used to clear hash-table 125 remoteIndex uint32 // index for sending 126 remoteStatic NoisePublicKey // long term key 127 remoteEphemeral NoisePublicKey // ephemeral public key 128 precomputedStaticStatic [NoisePublicKeySize]byte // precomputed shared secret 129 lastTimestamp tai64n.Timestamp 130 lastInitiationConsumption time.Time 131 lastSentHandshake time.Time 132 } 133 134 var ( 135 InitialChainKey [blake2s.Size]byte 136 InitialHash [blake2s.Size]byte 137 ZeroNonce [chacha20poly1305.NonceSize]byte 138 ) 139 140 func mixKey(dst, c *[blake2s.Size]byte, data []byte) { 141 KDF1(dst, c[:], data) 142 } 143 144 func mixHash(dst, h *[blake2s.Size]byte, data []byte) { 145 hash, _ := blake2s.New256(nil) 146 hash.Write(h[:]) 147 hash.Write(data) 148 hash.Sum(dst[:0]) 149 hash.Reset() 150 } 151 152 func (h *Handshake) Clear() { 153 setZero(h.localEphemeral[:]) 154 setZero(h.remoteEphemeral[:]) 155 setZero(h.chainKey[:]) 156 setZero(h.hash[:]) 157 h.localIndex = 0 158 h.state = handshakeZeroed 159 } 160 161 func (h *Handshake) mixHash(data []byte) { 162 mixHash(&h.hash, &h.hash, data) 163 } 164 165 func (h *Handshake) mixKey(data []byte) { 166 mixKey(&h.chainKey, &h.chainKey, data) 167 } 168 169 /* Do basic precomputations 170 */ 171 func init() { 172 InitialChainKey = blake2s.Sum256([]byte(NoiseConstruction)) 173 mixHash(&InitialHash, &InitialChainKey, []byte(WGIdentifier)) 174 } 175 176 func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { 177 device.staticIdentity.RLock() 178 defer device.staticIdentity.RUnlock() 179 180 handshake := &peer.handshake 181 handshake.mutex.Lock() 182 defer handshake.mutex.Unlock() 183 184 // create ephemeral key 185 var err error 186 handshake.hash = InitialHash 187 handshake.chainKey = InitialChainKey 188 handshake.localEphemeral, err = newPrivateKey() 189 if err != nil { 190 return nil, err 191 } 192 193 handshake.mixHash(handshake.remoteStatic[:]) 194 195 msg := MessageInitiation{ 196 Type: MessageInitiationType, 197 Ephemeral: handshake.localEphemeral.publicKey(), 198 } 199 200 handshake.mixKey(msg.Ephemeral[:]) 201 handshake.mixHash(msg.Ephemeral[:]) 202 203 // encrypt static key 204 ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteStatic) 205 if err != nil { 206 return nil, err 207 } 208 var key [chacha20poly1305.KeySize]byte 209 KDF2( 210 &handshake.chainKey, 211 &key, 212 handshake.chainKey[:], 213 ss[:], 214 ) 215 aead, _ := chacha20poly1305.New(key[:]) 216 aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:]) 217 handshake.mixHash(msg.Static[:]) 218 219 // encrypt timestamp 220 if isZero(handshake.precomputedStaticStatic[:]) { 221 return nil, errInvalidPublicKey 222 } 223 KDF2( 224 &handshake.chainKey, 225 &key, 226 handshake.chainKey[:], 227 handshake.precomputedStaticStatic[:], 228 ) 229 timestamp := tai64n.Now() 230 aead, _ = chacha20poly1305.New(key[:]) 231 aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:]) 232 233 // assign index 234 device.indexTable.Delete(handshake.localIndex) 235 msg.Sender, err = device.indexTable.NewIndexForHandshake(peer, handshake) 236 if err != nil { 237 return nil, err 238 } 239 handshake.localIndex = msg.Sender 240 241 handshake.mixHash(msg.Timestamp[:]) 242 handshake.state = handshakeInitiationCreated 243 return &msg, nil 244 } 245 246 func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { 247 var ( 248 hash [blake2s.Size]byte 249 chainKey [blake2s.Size]byte 250 ) 251 252 if msg.Type != MessageInitiationType { 253 return nil 254 } 255 256 device.staticIdentity.RLock() 257 defer device.staticIdentity.RUnlock() 258 259 mixHash(&hash, &InitialHash, device.staticIdentity.publicKey[:]) 260 mixHash(&hash, &hash, msg.Ephemeral[:]) 261 mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:]) 262 263 // decrypt static key 264 var peerPK NoisePublicKey 265 var key [chacha20poly1305.KeySize]byte 266 ss, err := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) 267 if err != nil { 268 return nil 269 } 270 KDF2(&chainKey, &key, chainKey[:], ss[:]) 271 aead, _ := chacha20poly1305.New(key[:]) 272 _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:]) 273 if err != nil { 274 return nil 275 } 276 mixHash(&hash, &hash, msg.Static[:]) 277 278 // lookup peer 279 280 peer := device.LookupPeer(peerPK) 281 if peer == nil || !peer.isRunning.Load() { 282 return nil 283 } 284 285 handshake := &peer.handshake 286 287 // verify identity 288 289 var timestamp tai64n.Timestamp 290 291 handshake.mutex.RLock() 292 293 if isZero(handshake.precomputedStaticStatic[:]) { 294 handshake.mutex.RUnlock() 295 return nil 296 } 297 KDF2( 298 &chainKey, 299 &key, 300 chainKey[:], 301 handshake.precomputedStaticStatic[:], 302 ) 303 aead, _ = chacha20poly1305.New(key[:]) 304 _, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:]) 305 if err != nil { 306 handshake.mutex.RUnlock() 307 return nil 308 } 309 mixHash(&hash, &hash, msg.Timestamp[:]) 310 311 // protect against replay & flood 312 313 replay := !timestamp.After(handshake.lastTimestamp) 314 flood := time.Since(handshake.lastInitiationConsumption) <= HandshakeInitationRate 315 handshake.mutex.RUnlock() 316 if replay { 317 device.log.Verbosef("%v - ConsumeMessageInitiation: handshake replay @ %v", peer, timestamp) 318 return nil 319 } 320 if flood { 321 device.log.Verbosef("%v - ConsumeMessageInitiation: handshake flood", peer) 322 return nil 323 } 324 325 // update handshake state 326 327 handshake.mutex.Lock() 328 329 handshake.hash = hash 330 handshake.chainKey = chainKey 331 handshake.remoteIndex = msg.Sender 332 handshake.remoteEphemeral = msg.Ephemeral 333 if timestamp.After(handshake.lastTimestamp) { 334 handshake.lastTimestamp = timestamp 335 } 336 now := time.Now() 337 if now.After(handshake.lastInitiationConsumption) { 338 handshake.lastInitiationConsumption = now 339 } 340 handshake.state = handshakeInitiationConsumed 341 342 handshake.mutex.Unlock() 343 344 setZero(hash[:]) 345 setZero(chainKey[:]) 346 347 return peer 348 } 349 350 func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error) { 351 handshake := &peer.handshake 352 handshake.mutex.Lock() 353 defer handshake.mutex.Unlock() 354 355 if handshake.state != handshakeInitiationConsumed { 356 return nil, errors.New("handshake initiation must be consumed first") 357 } 358 359 // assign index 360 361 var err error 362 device.indexTable.Delete(handshake.localIndex) 363 handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake) 364 if err != nil { 365 return nil, err 366 } 367 368 var msg MessageResponse 369 msg.Type = MessageResponseType 370 msg.Sender = handshake.localIndex 371 msg.Receiver = handshake.remoteIndex 372 373 // create ephemeral key 374 375 handshake.localEphemeral, err = newPrivateKey() 376 if err != nil { 377 return nil, err 378 } 379 msg.Ephemeral = handshake.localEphemeral.publicKey() 380 handshake.mixHash(msg.Ephemeral[:]) 381 handshake.mixKey(msg.Ephemeral[:]) 382 383 ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral) 384 if err != nil { 385 return nil, err 386 } 387 handshake.mixKey(ss[:]) 388 ss, err = handshake.localEphemeral.sharedSecret(handshake.remoteStatic) 389 if err != nil { 390 return nil, err 391 } 392 handshake.mixKey(ss[:]) 393 394 // add preshared key 395 396 var tau [blake2s.Size]byte 397 var key [chacha20poly1305.KeySize]byte 398 399 KDF3( 400 &handshake.chainKey, 401 &tau, 402 &key, 403 handshake.chainKey[:], 404 handshake.presharedKey[:], 405 ) 406 407 handshake.mixHash(tau[:]) 408 409 aead, _ := chacha20poly1305.New(key[:]) 410 aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:]) 411 handshake.mixHash(msg.Empty[:]) 412 413 handshake.state = handshakeResponseCreated 414 415 return &msg, nil 416 } 417 418 func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { 419 if msg.Type != MessageResponseType { 420 return nil 421 } 422 423 // lookup handshake by receiver 424 425 lookup := device.indexTable.Lookup(msg.Receiver) 426 handshake := lookup.handshake 427 if handshake == nil { 428 return nil 429 } 430 431 var ( 432 hash [blake2s.Size]byte 433 chainKey [blake2s.Size]byte 434 ) 435 436 ok := func() bool { 437 // lock handshake state 438 439 handshake.mutex.RLock() 440 defer handshake.mutex.RUnlock() 441 442 if handshake.state != handshakeInitiationCreated { 443 return false 444 } 445 446 // lock private key for reading 447 448 device.staticIdentity.RLock() 449 defer device.staticIdentity.RUnlock() 450 451 // finish 3-way DH 452 453 mixHash(&hash, &handshake.hash, msg.Ephemeral[:]) 454 mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:]) 455 456 ss, err := handshake.localEphemeral.sharedSecret(msg.Ephemeral) 457 if err != nil { 458 return false 459 } 460 mixKey(&chainKey, &chainKey, ss[:]) 461 setZero(ss[:]) 462 463 ss, err = device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) 464 if err != nil { 465 return false 466 } 467 mixKey(&chainKey, &chainKey, ss[:]) 468 setZero(ss[:]) 469 470 // add preshared key (psk) 471 472 var tau [blake2s.Size]byte 473 var key [chacha20poly1305.KeySize]byte 474 KDF3( 475 &chainKey, 476 &tau, 477 &key, 478 chainKey[:], 479 handshake.presharedKey[:], 480 ) 481 mixHash(&hash, &hash, tau[:]) 482 483 // authenticate transcript 484 485 aead, _ := chacha20poly1305.New(key[:]) 486 _, err = aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:]) 487 if err != nil { 488 return false 489 } 490 mixHash(&hash, &hash, msg.Empty[:]) 491 return true 492 }() 493 494 if !ok { 495 return nil 496 } 497 498 // update handshake state 499 500 handshake.mutex.Lock() 501 502 handshake.hash = hash 503 handshake.chainKey = chainKey 504 handshake.remoteIndex = msg.Sender 505 handshake.state = handshakeResponseConsumed 506 507 handshake.mutex.Unlock() 508 509 setZero(hash[:]) 510 setZero(chainKey[:]) 511 512 return lookup.peer 513 } 514 515 /* Derives a new keypair from the current handshake state 516 * 517 */ 518 func (peer *Peer) BeginSymmetricSession() error { 519 device := peer.device 520 handshake := &peer.handshake 521 handshake.mutex.Lock() 522 defer handshake.mutex.Unlock() 523 524 // derive keys 525 526 var isInitiator bool 527 var sendKey [chacha20poly1305.KeySize]byte 528 var recvKey [chacha20poly1305.KeySize]byte 529 530 if handshake.state == handshakeResponseConsumed { 531 KDF2( 532 &sendKey, 533 &recvKey, 534 handshake.chainKey[:], 535 nil, 536 ) 537 isInitiator = true 538 } else if handshake.state == handshakeResponseCreated { 539 KDF2( 540 &recvKey, 541 &sendKey, 542 handshake.chainKey[:], 543 nil, 544 ) 545 isInitiator = false 546 } else { 547 return fmt.Errorf("invalid state for keypair derivation: %v", handshake.state) 548 } 549 550 // zero handshake 551 552 setZero(handshake.chainKey[:]) 553 setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line. 554 setZero(handshake.localEphemeral[:]) 555 peer.handshake.state = handshakeZeroed 556 557 // create AEAD instances 558 559 keypair := new(Keypair) 560 keypair.send, _ = chacha20poly1305.New(sendKey[:]) 561 keypair.receive, _ = chacha20poly1305.New(recvKey[:]) 562 563 setZero(sendKey[:]) 564 setZero(recvKey[:]) 565 566 keypair.created = time.Now() 567 keypair.replayFilter.Reset() 568 keypair.isInitiator = isInitiator 569 keypair.localIndex = peer.handshake.localIndex 570 keypair.remoteIndex = peer.handshake.remoteIndex 571 572 // remap index 573 574 device.indexTable.SwapIndexForKeypair(handshake.localIndex, keypair) 575 handshake.localIndex = 0 576 577 // rotate key pairs 578 579 keypairs := &peer.keypairs 580 keypairs.Lock() 581 defer keypairs.Unlock() 582 583 previous := keypairs.previous 584 next := keypairs.next.Load() 585 current := keypairs.current 586 587 if isInitiator { 588 if next != nil { 589 keypairs.next.Store(nil) 590 keypairs.previous = next 591 device.DeleteKeypair(current) 592 } else { 593 keypairs.previous = current 594 } 595 device.DeleteKeypair(previous) 596 keypairs.current = keypair 597 } else { 598 keypairs.next.Store(keypair) 599 device.DeleteKeypair(next) 600 keypairs.previous = nil 601 device.DeleteKeypair(previous) 602 } 603 604 return nil 605 } 606 607 func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool { 608 keypairs := &peer.keypairs 609 610 if keypairs.next.Load() != receivedKeypair { 611 return false 612 } 613 keypairs.Lock() 614 defer keypairs.Unlock() 615 if keypairs.next.Load() != receivedKeypair { 616 return false 617 } 618 old := keypairs.previous 619 keypairs.previous = keypairs.current 620 peer.device.DeleteKeypair(old) 621 keypairs.current = keypairs.next.Load() 622 keypairs.next.Store(nil) 623 return true 624 }