github.com/GFW-knocker/wireguard@v1.0.1/device/noise-protocol.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package device 7 8 import ( 9 "errors" 10 "fmt" 11 "sync" 12 "time" 13 14 "golang.org/x/crypto/blake2s" 15 "golang.org/x/crypto/chacha20poly1305" 16 "golang.org/x/crypto/poly1305" 17 18 "github.com/GFW-knocker/wireguard/tai64n" 19 ) 20 21 type handshakeState int 22 23 const ( 24 handshakeZeroed = handshakeState(iota) 25 handshakeInitiationCreated 26 handshakeInitiationConsumed 27 handshakeResponseCreated 28 handshakeResponseConsumed 29 ) 30 31 func (hs handshakeState) String() string { 32 switch hs { 33 case handshakeZeroed: 34 return "handshakeZeroed" 35 case handshakeInitiationCreated: 36 return "handshakeInitiationCreated" 37 case handshakeInitiationConsumed: 38 return "handshakeInitiationConsumed" 39 case handshakeResponseCreated: 40 return "handshakeResponseCreated" 41 case handshakeResponseConsumed: 42 return "handshakeResponseConsumed" 43 default: 44 return fmt.Sprintf("Handshake(UNKNOWN:%d)", int(hs)) 45 } 46 } 47 48 const ( 49 NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s" 50 WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com" 51 WGLabelMAC1 = "mac1----" 52 WGLabelCookie = "cookie--" 53 ) 54 55 const ( 56 MessageInitiationType = 1 57 MessageResponseType = 2 58 MessageCookieReplyType = 3 59 MessageTransportType = 4 60 ) 61 62 const ( 63 MessageInitiationSize = 148 // size of handshake initiation message 64 MessageResponseSize = 92 // size of response message 65 MessageCookieReplySize = 64 // size of cookie reply message 66 MessageTransportHeaderSize = 16 // size of data preceding content in transport message 67 MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport 68 MessageKeepaliveSize = MessageTransportSize // size of keepalive 69 MessageHandshakeSize = MessageInitiationSize // size of largest handshake related message 70 ) 71 72 const ( 73 MessageTransportOffsetReceiver = 4 74 MessageTransportOffsetCounter = 8 75 MessageTransportOffsetContent = 16 76 ) 77 78 /* Type is an 8-bit field, followed by 3 nul bytes, 79 * by marshalling the messages in little-endian byteorder 80 * we can treat these as a 32-bit unsigned int (for now) 81 * 82 */ 83 84 type MessageInitiation struct { 85 Type uint32 86 Sender uint32 87 Ephemeral NoisePublicKey 88 Static [NoisePublicKeySize + poly1305.TagSize]byte 89 Timestamp [tai64n.TimestampSize + poly1305.TagSize]byte 90 MAC1 [blake2s.Size128]byte 91 MAC2 [blake2s.Size128]byte 92 } 93 94 type MessageResponse struct { 95 Type uint32 96 Sender uint32 97 Receiver uint32 98 Ephemeral NoisePublicKey 99 Empty [poly1305.TagSize]byte 100 MAC1 [blake2s.Size128]byte 101 MAC2 [blake2s.Size128]byte 102 } 103 104 type MessageTransport struct { 105 Type uint32 106 Receiver uint32 107 Counter uint64 108 Content []byte 109 } 110 111 type MessageCookieReply struct { 112 Type uint32 113 Receiver uint32 114 Nonce [chacha20poly1305.NonceSizeX]byte 115 Cookie [blake2s.Size128 + poly1305.TagSize]byte 116 } 117 118 type Handshake struct { 119 state handshakeState 120 mutex sync.RWMutex 121 hash [blake2s.Size]byte // hash value 122 chainKey [blake2s.Size]byte // chain key 123 presharedKey NoisePresharedKey // psk 124 localEphemeral NoisePrivateKey // ephemeral secret key 125 localIndex uint32 // used to clear hash-table 126 remoteIndex uint32 // index for sending 127 remoteStatic NoisePublicKey // long term key 128 remoteEphemeral NoisePublicKey // ephemeral public key 129 precomputedStaticStatic [NoisePublicKeySize]byte // precomputed shared secret 130 lastTimestamp tai64n.Timestamp 131 lastInitiationConsumption time.Time 132 lastSentHandshake time.Time 133 } 134 135 var ( 136 InitialChainKey [blake2s.Size]byte 137 InitialHash [blake2s.Size]byte 138 ZeroNonce [chacha20poly1305.NonceSize]byte 139 ) 140 141 func mixKey(dst, c *[blake2s.Size]byte, data []byte) { 142 KDF1(dst, c[:], data) 143 } 144 145 func mixHash(dst, h *[blake2s.Size]byte, data []byte) { 146 hash, _ := blake2s.New256(nil) 147 hash.Write(h[:]) 148 hash.Write(data) 149 hash.Sum(dst[:0]) 150 hash.Reset() 151 } 152 153 func (h *Handshake) Clear() { 154 setZero(h.localEphemeral[:]) 155 setZero(h.remoteEphemeral[:]) 156 setZero(h.chainKey[:]) 157 setZero(h.hash[:]) 158 h.localIndex = 0 159 h.state = handshakeZeroed 160 } 161 162 func (h *Handshake) mixHash(data []byte) { 163 mixHash(&h.hash, &h.hash, data) 164 } 165 166 func (h *Handshake) mixKey(data []byte) { 167 mixKey(&h.chainKey, &h.chainKey, data) 168 } 169 170 /* Do basic precomputations 171 */ 172 func init() { 173 InitialChainKey = blake2s.Sum256([]byte(NoiseConstruction)) 174 mixHash(&InitialHash, &InitialChainKey, []byte(WGIdentifier)) 175 } 176 177 func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { 178 device.staticIdentity.RLock() 179 defer device.staticIdentity.RUnlock() 180 181 handshake := &peer.handshake 182 handshake.mutex.Lock() 183 defer handshake.mutex.Unlock() 184 185 // create ephemeral key 186 var err error 187 handshake.hash = InitialHash 188 handshake.chainKey = InitialChainKey 189 handshake.localEphemeral, err = newPrivateKey() 190 if err != nil { 191 return nil, err 192 } 193 194 handshake.mixHash(handshake.remoteStatic[:]) 195 196 msg := MessageInitiation{ 197 Type: MessageInitiationType, 198 Ephemeral: handshake.localEphemeral.publicKey(), 199 } 200 201 handshake.mixKey(msg.Ephemeral[:]) 202 handshake.mixHash(msg.Ephemeral[:]) 203 204 // encrypt static key 205 ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteStatic) 206 if err != nil { 207 return nil, err 208 } 209 var key [chacha20poly1305.KeySize]byte 210 KDF2( 211 &handshake.chainKey, 212 &key, 213 handshake.chainKey[:], 214 ss[:], 215 ) 216 aead, _ := chacha20poly1305.New(key[:]) 217 aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:]) 218 handshake.mixHash(msg.Static[:]) 219 220 // encrypt timestamp 221 if isZero(handshake.precomputedStaticStatic[:]) { 222 return nil, errInvalidPublicKey 223 } 224 KDF2( 225 &handshake.chainKey, 226 &key, 227 handshake.chainKey[:], 228 handshake.precomputedStaticStatic[:], 229 ) 230 timestamp := tai64n.Now() 231 aead, _ = chacha20poly1305.New(key[:]) 232 aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:]) 233 234 // assign index 235 device.indexTable.Delete(handshake.localIndex) 236 msg.Sender, err = device.indexTable.NewIndexForHandshake(peer, handshake) 237 if err != nil { 238 return nil, err 239 } 240 handshake.localIndex = msg.Sender 241 242 handshake.mixHash(msg.Timestamp[:]) 243 handshake.state = handshakeInitiationCreated 244 return &msg, nil 245 } 246 247 func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { 248 var ( 249 hash [blake2s.Size]byte 250 chainKey [blake2s.Size]byte 251 ) 252 253 if msg.Type != MessageInitiationType { 254 return nil 255 } 256 257 device.staticIdentity.RLock() 258 defer device.staticIdentity.RUnlock() 259 260 mixHash(&hash, &InitialHash, device.staticIdentity.publicKey[:]) 261 mixHash(&hash, &hash, msg.Ephemeral[:]) 262 mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:]) 263 264 // decrypt static key 265 var peerPK NoisePublicKey 266 var key [chacha20poly1305.KeySize]byte 267 ss, err := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) 268 if err != nil { 269 return nil 270 } 271 KDF2(&chainKey, &key, chainKey[:], ss[:]) 272 aead, _ := chacha20poly1305.New(key[:]) 273 _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:]) 274 if err != nil { 275 return nil 276 } 277 mixHash(&hash, &hash, msg.Static[:]) 278 279 // lookup peer 280 281 peer := device.LookupPeer(peerPK) 282 if peer == nil || !peer.isRunning.Load() { 283 return nil 284 } 285 286 handshake := &peer.handshake 287 288 // verify identity 289 290 var timestamp tai64n.Timestamp 291 292 handshake.mutex.RLock() 293 294 if isZero(handshake.precomputedStaticStatic[:]) { 295 handshake.mutex.RUnlock() 296 return nil 297 } 298 KDF2( 299 &chainKey, 300 &key, 301 chainKey[:], 302 handshake.precomputedStaticStatic[:], 303 ) 304 aead, _ = chacha20poly1305.New(key[:]) 305 _, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:]) 306 if err != nil { 307 handshake.mutex.RUnlock() 308 return nil 309 } 310 mixHash(&hash, &hash, msg.Timestamp[:]) 311 312 // protect against replay & flood 313 314 replay := !timestamp.After(handshake.lastTimestamp) 315 flood := time.Since(handshake.lastInitiationConsumption) <= HandshakeInitationRate 316 handshake.mutex.RUnlock() 317 if replay { 318 device.log.Verbosef("%v - ConsumeMessageInitiation: handshake replay @ %v", peer, timestamp) 319 return nil 320 } 321 if flood { 322 device.log.Verbosef("%v - ConsumeMessageInitiation: handshake flood", peer) 323 return nil 324 } 325 326 // update handshake state 327 328 handshake.mutex.Lock() 329 330 handshake.hash = hash 331 handshake.chainKey = chainKey 332 handshake.remoteIndex = msg.Sender 333 handshake.remoteEphemeral = msg.Ephemeral 334 if timestamp.After(handshake.lastTimestamp) { 335 handshake.lastTimestamp = timestamp 336 } 337 now := time.Now() 338 if now.After(handshake.lastInitiationConsumption) { 339 handshake.lastInitiationConsumption = now 340 } 341 handshake.state = handshakeInitiationConsumed 342 343 handshake.mutex.Unlock() 344 345 setZero(hash[:]) 346 setZero(chainKey[:]) 347 348 return peer 349 } 350 351 func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error) { 352 handshake := &peer.handshake 353 handshake.mutex.Lock() 354 defer handshake.mutex.Unlock() 355 356 if handshake.state != handshakeInitiationConsumed { 357 return nil, errors.New("handshake initiation must be consumed first") 358 } 359 360 // assign index 361 362 var err error 363 device.indexTable.Delete(handshake.localIndex) 364 handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake) 365 if err != nil { 366 return nil, err 367 } 368 369 var msg MessageResponse 370 msg.Type = MessageResponseType 371 msg.Sender = handshake.localIndex 372 msg.Receiver = handshake.remoteIndex 373 374 // create ephemeral key 375 376 handshake.localEphemeral, err = newPrivateKey() 377 if err != nil { 378 return nil, err 379 } 380 msg.Ephemeral = handshake.localEphemeral.publicKey() 381 handshake.mixHash(msg.Ephemeral[:]) 382 handshake.mixKey(msg.Ephemeral[:]) 383 384 ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral) 385 if err != nil { 386 return nil, err 387 } 388 handshake.mixKey(ss[:]) 389 ss, err = handshake.localEphemeral.sharedSecret(handshake.remoteStatic) 390 if err != nil { 391 return nil, err 392 } 393 handshake.mixKey(ss[:]) 394 395 // add preshared key 396 397 var tau [blake2s.Size]byte 398 var key [chacha20poly1305.KeySize]byte 399 400 KDF3( 401 &handshake.chainKey, 402 &tau, 403 &key, 404 handshake.chainKey[:], 405 handshake.presharedKey[:], 406 ) 407 408 handshake.mixHash(tau[:]) 409 410 aead, _ := chacha20poly1305.New(key[:]) 411 aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:]) 412 handshake.mixHash(msg.Empty[:]) 413 414 handshake.state = handshakeResponseCreated 415 416 return &msg, nil 417 } 418 419 func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { 420 if msg.Type != MessageResponseType { 421 return nil 422 } 423 424 // lookup handshake by receiver 425 426 lookup := device.indexTable.Lookup(msg.Receiver) 427 handshake := lookup.handshake 428 if handshake == nil { 429 return nil 430 } 431 432 var ( 433 hash [blake2s.Size]byte 434 chainKey [blake2s.Size]byte 435 ) 436 437 ok := func() bool { 438 // lock handshake state 439 440 handshake.mutex.RLock() 441 defer handshake.mutex.RUnlock() 442 443 if handshake.state != handshakeInitiationCreated { 444 return false 445 } 446 447 // lock private key for reading 448 449 device.staticIdentity.RLock() 450 defer device.staticIdentity.RUnlock() 451 452 // finish 3-way DH 453 454 mixHash(&hash, &handshake.hash, msg.Ephemeral[:]) 455 mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:]) 456 457 ss, err := handshake.localEphemeral.sharedSecret(msg.Ephemeral) 458 if err != nil { 459 return false 460 } 461 mixKey(&chainKey, &chainKey, ss[:]) 462 setZero(ss[:]) 463 464 ss, err = device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) 465 if err != nil { 466 return false 467 } 468 mixKey(&chainKey, &chainKey, ss[:]) 469 setZero(ss[:]) 470 471 // add preshared key (psk) 472 473 var tau [blake2s.Size]byte 474 var key [chacha20poly1305.KeySize]byte 475 KDF3( 476 &chainKey, 477 &tau, 478 &key, 479 chainKey[:], 480 handshake.presharedKey[:], 481 ) 482 mixHash(&hash, &hash, tau[:]) 483 484 // authenticate transcript 485 486 aead, _ := chacha20poly1305.New(key[:]) 487 _, err = aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:]) 488 if err != nil { 489 return false 490 } 491 mixHash(&hash, &hash, msg.Empty[:]) 492 return true 493 }() 494 495 if !ok { 496 return nil 497 } 498 499 // update handshake state 500 501 handshake.mutex.Lock() 502 503 handshake.hash = hash 504 handshake.chainKey = chainKey 505 handshake.remoteIndex = msg.Sender 506 handshake.state = handshakeResponseConsumed 507 508 handshake.mutex.Unlock() 509 510 setZero(hash[:]) 511 setZero(chainKey[:]) 512 513 return lookup.peer 514 } 515 516 /* Derives a new keypair from the current handshake state 517 * 518 */ 519 func (peer *Peer) BeginSymmetricSession() error { 520 device := peer.device 521 handshake := &peer.handshake 522 handshake.mutex.Lock() 523 defer handshake.mutex.Unlock() 524 525 // derive keys 526 527 var isInitiator bool 528 var sendKey [chacha20poly1305.KeySize]byte 529 var recvKey [chacha20poly1305.KeySize]byte 530 531 if handshake.state == handshakeResponseConsumed { 532 KDF2( 533 &sendKey, 534 &recvKey, 535 handshake.chainKey[:], 536 nil, 537 ) 538 isInitiator = true 539 } else if handshake.state == handshakeResponseCreated { 540 KDF2( 541 &recvKey, 542 &sendKey, 543 handshake.chainKey[:], 544 nil, 545 ) 546 isInitiator = false 547 } else { 548 return fmt.Errorf("invalid state for keypair derivation: %v", handshake.state) 549 } 550 551 // zero handshake 552 553 setZero(handshake.chainKey[:]) 554 setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line. 555 setZero(handshake.localEphemeral[:]) 556 peer.handshake.state = handshakeZeroed 557 558 // create AEAD instances 559 560 keypair := new(Keypair) 561 keypair.send, _ = chacha20poly1305.New(sendKey[:]) 562 keypair.receive, _ = chacha20poly1305.New(recvKey[:]) 563 564 setZero(sendKey[:]) 565 setZero(recvKey[:]) 566 567 keypair.created = time.Now() 568 keypair.replayFilter.Reset() 569 keypair.isInitiator = isInitiator 570 keypair.localIndex = peer.handshake.localIndex 571 keypair.remoteIndex = peer.handshake.remoteIndex 572 573 // remap index 574 575 device.indexTable.SwapIndexForKeypair(handshake.localIndex, keypair) 576 handshake.localIndex = 0 577 578 // rotate key pairs 579 580 keypairs := &peer.keypairs 581 keypairs.Lock() 582 defer keypairs.Unlock() 583 584 previous := keypairs.previous 585 next := keypairs.next.Load() 586 current := keypairs.current 587 588 if isInitiator { 589 if next != nil { 590 keypairs.next.Store(nil) 591 keypairs.previous = next 592 device.DeleteKeypair(current) 593 } else { 594 keypairs.previous = current 595 } 596 device.DeleteKeypair(previous) 597 keypairs.current = keypair 598 } else { 599 keypairs.next.Store(keypair) 600 device.DeleteKeypair(next) 601 keypairs.previous = nil 602 device.DeleteKeypair(previous) 603 } 604 605 return nil 606 } 607 608 func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool { 609 keypairs := &peer.keypairs 610 611 if keypairs.next.Load() != receivedKeypair { 612 return false 613 } 614 keypairs.Lock() 615 defer keypairs.Unlock() 616 if keypairs.next.Load() != receivedKeypair { 617 return false 618 } 619 old := keypairs.previous 620 keypairs.previous = keypairs.current 621 peer.device.DeleteKeypair(old) 622 keypairs.current = keypairs.next.Load() 623 keypairs.next.Store(nil) 624 return true 625 }