github.com/liloew/wireguard-go@v0.0.0-20220224014633-9cd745e6f114/device/noise-protocol.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package device 7 8 import ( 9 "errors" 10 "fmt" 11 "sync" 12 "time" 13 14 "golang.org/x/crypto/blake2s" 15 "golang.org/x/crypto/chacha20poly1305" 16 "golang.org/x/crypto/poly1305" 17 18 "github.com/liloew/wireguard-go/tai64n" 19 ) 20 21 type handshakeState int 22 23 const ( 24 handshakeZeroed = handshakeState(iota) 25 handshakeInitiationCreated 26 handshakeInitiationConsumed 27 handshakeResponseCreated 28 handshakeResponseConsumed 29 ) 30 31 func (hs handshakeState) String() string { 32 switch hs { 33 case handshakeZeroed: 34 return "handshakeZeroed" 35 case handshakeInitiationCreated: 36 return "handshakeInitiationCreated" 37 case handshakeInitiationConsumed: 38 return "handshakeInitiationConsumed" 39 case handshakeResponseCreated: 40 return "handshakeResponseCreated" 41 case handshakeResponseConsumed: 42 return "handshakeResponseConsumed" 43 default: 44 return fmt.Sprintf("Handshake(UNKNOWN:%d)", int(hs)) 45 } 46 } 47 48 const ( 49 NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s" 50 WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com" 51 WGLabelMAC1 = "mac1----" 52 WGLabelCookie = "cookie--" 53 ) 54 55 const ( 56 MessageInitiationType = 1 57 MessageResponseType = 2 58 MessageCookieReplyType = 3 59 MessageTransportType = 4 60 ) 61 62 const ( 63 MessageInitiationSize = 148 // size of handshake initiation message 64 MessageResponseSize = 92 // size of response message 65 MessageCookieReplySize = 64 // size of cookie reply message 66 MessageTransportHeaderSize = 16 // size of data preceding content in transport message 67 MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport 68 MessageKeepaliveSize = MessageTransportSize // size of keepalive 69 MessageHandshakeSize = MessageInitiationSize // size of largest handshake related message 70 ) 71 72 const ( 73 MessageTransportOffsetReceiver = 4 74 MessageTransportOffsetCounter = 8 75 MessageTransportOffsetContent = 16 76 ) 77 78 /* Type is an 8-bit field, followed by 3 nul bytes, 79 * by marshalling the messages in little-endian byteorder 80 * we can treat these as a 32-bit unsigned int (for now) 81 * 82 */ 83 84 type MessageInitiation struct { 85 Type uint32 86 Sender uint32 87 Ephemeral NoisePublicKey 88 Static [NoisePublicKeySize + poly1305.TagSize]byte 89 Timestamp [tai64n.TimestampSize + poly1305.TagSize]byte 90 MAC1 [blake2s.Size128]byte 91 MAC2 [blake2s.Size128]byte 92 } 93 94 type MessageResponse struct { 95 Type uint32 96 Sender uint32 97 Receiver uint32 98 Ephemeral NoisePublicKey 99 Empty [poly1305.TagSize]byte 100 MAC1 [blake2s.Size128]byte 101 MAC2 [blake2s.Size128]byte 102 } 103 104 type MessageTransport struct { 105 Type uint32 106 Receiver uint32 107 Counter uint64 108 Content []byte 109 } 110 111 type MessageCookieReply struct { 112 Type uint32 113 Receiver uint32 114 Nonce [chacha20poly1305.NonceSizeX]byte 115 Cookie [blake2s.Size128 + poly1305.TagSize]byte 116 } 117 118 type Handshake struct { 119 state handshakeState 120 mutex sync.RWMutex 121 hash [blake2s.Size]byte // hash value 122 chainKey [blake2s.Size]byte // chain key 123 presharedKey NoisePresharedKey // psk 124 localEphemeral NoisePrivateKey // ephemeral secret key 125 localIndex uint32 // used to clear hash-table 126 remoteIndex uint32 // index for sending 127 remoteStatic NoisePublicKey // long term key 128 remoteEphemeral NoisePublicKey // ephemeral public key 129 precomputedStaticStatic [NoisePublicKeySize]byte // precomputed shared secret 130 lastTimestamp tai64n.Timestamp 131 lastInitiationConsumption time.Time 132 lastSentHandshake time.Time 133 } 134 135 var ( 136 InitialChainKey [blake2s.Size]byte 137 InitialHash [blake2s.Size]byte 138 ZeroNonce [chacha20poly1305.NonceSize]byte 139 ) 140 141 func mixKey(dst, c *[blake2s.Size]byte, data []byte) { 142 KDF1(dst, c[:], data) 143 } 144 145 func mixHash(dst, h *[blake2s.Size]byte, data []byte) { 146 hash, _ := blake2s.New256(nil) 147 hash.Write(h[:]) 148 hash.Write(data) 149 hash.Sum(dst[:0]) 150 hash.Reset() 151 } 152 153 func (h *Handshake) Clear() { 154 setZero(h.localEphemeral[:]) 155 setZero(h.remoteEphemeral[:]) 156 setZero(h.chainKey[:]) 157 setZero(h.hash[:]) 158 h.localIndex = 0 159 h.state = handshakeZeroed 160 } 161 162 func (h *Handshake) mixHash(data []byte) { 163 mixHash(&h.hash, &h.hash, data) 164 } 165 166 func (h *Handshake) mixKey(data []byte) { 167 mixKey(&h.chainKey, &h.chainKey, data) 168 } 169 170 /* Do basic precomputations 171 */ 172 func init() { 173 InitialChainKey = blake2s.Sum256([]byte(NoiseConstruction)) 174 mixHash(&InitialHash, &InitialChainKey, []byte(WGIdentifier)) 175 } 176 177 func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { 178 errZeroECDHResult := errors.New("ECDH returned all zeros") 179 180 device.staticIdentity.RLock() 181 defer device.staticIdentity.RUnlock() 182 183 handshake := &peer.handshake 184 handshake.mutex.Lock() 185 defer handshake.mutex.Unlock() 186 187 // create ephemeral key 188 var err error 189 handshake.hash = InitialHash 190 handshake.chainKey = InitialChainKey 191 handshake.localEphemeral, err = newPrivateKey() 192 if err != nil { 193 return nil, err 194 } 195 196 handshake.mixHash(handshake.remoteStatic[:]) 197 198 msg := MessageInitiation{ 199 Type: MessageInitiationType, 200 Ephemeral: handshake.localEphemeral.publicKey(), 201 } 202 203 handshake.mixKey(msg.Ephemeral[:]) 204 handshake.mixHash(msg.Ephemeral[:]) 205 206 // encrypt static key 207 ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic) 208 if isZero(ss[:]) { 209 return nil, errZeroECDHResult 210 } 211 var key [chacha20poly1305.KeySize]byte 212 KDF2( 213 &handshake.chainKey, 214 &key, 215 handshake.chainKey[:], 216 ss[:], 217 ) 218 aead, _ := chacha20poly1305.New(key[:]) 219 aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:]) 220 handshake.mixHash(msg.Static[:]) 221 222 // encrypt timestamp 223 if isZero(handshake.precomputedStaticStatic[:]) { 224 return nil, errZeroECDHResult 225 } 226 KDF2( 227 &handshake.chainKey, 228 &key, 229 handshake.chainKey[:], 230 handshake.precomputedStaticStatic[:], 231 ) 232 timestamp := tai64n.Now() 233 aead, _ = chacha20poly1305.New(key[:]) 234 aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:]) 235 236 // assign index 237 device.indexTable.Delete(handshake.localIndex) 238 msg.Sender, err = device.indexTable.NewIndexForHandshake(peer, handshake) 239 if err != nil { 240 return nil, err 241 } 242 handshake.localIndex = msg.Sender 243 244 handshake.mixHash(msg.Timestamp[:]) 245 handshake.state = handshakeInitiationCreated 246 return &msg, nil 247 } 248 249 func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { 250 var ( 251 hash [blake2s.Size]byte 252 chainKey [blake2s.Size]byte 253 ) 254 255 if msg.Type != MessageInitiationType { 256 return nil 257 } 258 259 device.staticIdentity.RLock() 260 defer device.staticIdentity.RUnlock() 261 262 mixHash(&hash, &InitialHash, device.staticIdentity.publicKey[:]) 263 mixHash(&hash, &hash, msg.Ephemeral[:]) 264 mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:]) 265 266 // decrypt static key 267 var err error 268 var peerPK NoisePublicKey 269 var key [chacha20poly1305.KeySize]byte 270 ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) 271 if isZero(ss[:]) { 272 return nil 273 } 274 KDF2(&chainKey, &key, chainKey[:], ss[:]) 275 aead, _ := chacha20poly1305.New(key[:]) 276 _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:]) 277 if err != nil { 278 return nil 279 } 280 mixHash(&hash, &hash, msg.Static[:]) 281 282 // lookup peer 283 284 peer := device.LookupPeer(peerPK) 285 if peer == nil || !peer.isRunning.Get() { 286 return nil 287 } 288 289 handshake := &peer.handshake 290 291 // verify identity 292 293 var timestamp tai64n.Timestamp 294 295 handshake.mutex.RLock() 296 297 if isZero(handshake.precomputedStaticStatic[:]) { 298 handshake.mutex.RUnlock() 299 return nil 300 } 301 KDF2( 302 &chainKey, 303 &key, 304 chainKey[:], 305 handshake.precomputedStaticStatic[:], 306 ) 307 aead, _ = chacha20poly1305.New(key[:]) 308 _, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:]) 309 if err != nil { 310 handshake.mutex.RUnlock() 311 return nil 312 } 313 mixHash(&hash, &hash, msg.Timestamp[:]) 314 315 // protect against replay & flood 316 317 replay := !timestamp.After(handshake.lastTimestamp) 318 flood := time.Since(handshake.lastInitiationConsumption) <= HandshakeInitationRate 319 handshake.mutex.RUnlock() 320 if replay { 321 device.log.Verbosef("%v - ConsumeMessageInitiation: handshake replay @ %v", peer, timestamp) 322 return nil 323 } 324 if flood { 325 device.log.Verbosef("%v - ConsumeMessageInitiation: handshake flood", peer) 326 return nil 327 } 328 329 // update handshake state 330 331 handshake.mutex.Lock() 332 333 handshake.hash = hash 334 handshake.chainKey = chainKey 335 handshake.remoteIndex = msg.Sender 336 handshake.remoteEphemeral = msg.Ephemeral 337 if timestamp.After(handshake.lastTimestamp) { 338 handshake.lastTimestamp = timestamp 339 } 340 now := time.Now() 341 if now.After(handshake.lastInitiationConsumption) { 342 handshake.lastInitiationConsumption = now 343 } 344 handshake.state = handshakeInitiationConsumed 345 346 handshake.mutex.Unlock() 347 348 setZero(hash[:]) 349 setZero(chainKey[:]) 350 351 return peer 352 } 353 354 func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error) { 355 handshake := &peer.handshake 356 handshake.mutex.Lock() 357 defer handshake.mutex.Unlock() 358 359 if handshake.state != handshakeInitiationConsumed { 360 return nil, errors.New("handshake initiation must be consumed first") 361 } 362 363 // assign index 364 365 var err error 366 device.indexTable.Delete(handshake.localIndex) 367 handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake) 368 if err != nil { 369 return nil, err 370 } 371 372 var msg MessageResponse 373 msg.Type = MessageResponseType 374 msg.Sender = handshake.localIndex 375 msg.Receiver = handshake.remoteIndex 376 377 // create ephemeral key 378 379 handshake.localEphemeral, err = newPrivateKey() 380 if err != nil { 381 return nil, err 382 } 383 msg.Ephemeral = handshake.localEphemeral.publicKey() 384 handshake.mixHash(msg.Ephemeral[:]) 385 handshake.mixKey(msg.Ephemeral[:]) 386 387 func() { 388 ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral) 389 handshake.mixKey(ss[:]) 390 ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic) 391 handshake.mixKey(ss[:]) 392 }() 393 394 // add preshared key 395 396 var tau [blake2s.Size]byte 397 var key [chacha20poly1305.KeySize]byte 398 399 KDF3( 400 &handshake.chainKey, 401 &tau, 402 &key, 403 handshake.chainKey[:], 404 handshake.presharedKey[:], 405 ) 406 407 handshake.mixHash(tau[:]) 408 409 func() { 410 aead, _ := chacha20poly1305.New(key[:]) 411 aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:]) 412 handshake.mixHash(msg.Empty[:]) 413 }() 414 415 handshake.state = handshakeResponseCreated 416 417 return &msg, nil 418 } 419 420 func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { 421 if msg.Type != MessageResponseType { 422 return nil 423 } 424 425 // lookup handshake by receiver 426 427 lookup := device.indexTable.Lookup(msg.Receiver) 428 handshake := lookup.handshake 429 if handshake == nil { 430 return nil 431 } 432 433 var ( 434 hash [blake2s.Size]byte 435 chainKey [blake2s.Size]byte 436 ) 437 438 ok := func() bool { 439 // lock handshake state 440 441 handshake.mutex.RLock() 442 defer handshake.mutex.RUnlock() 443 444 if handshake.state != handshakeInitiationCreated { 445 return false 446 } 447 448 // lock private key for reading 449 450 device.staticIdentity.RLock() 451 defer device.staticIdentity.RUnlock() 452 453 // finish 3-way DH 454 455 mixHash(&hash, &handshake.hash, msg.Ephemeral[:]) 456 mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:]) 457 458 func() { 459 ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral) 460 mixKey(&chainKey, &chainKey, ss[:]) 461 setZero(ss[:]) 462 }() 463 464 func() { 465 ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) 466 mixKey(&chainKey, &chainKey, ss[:]) 467 setZero(ss[:]) 468 }() 469 470 // add preshared key (psk) 471 472 var tau [blake2s.Size]byte 473 var key [chacha20poly1305.KeySize]byte 474 KDF3( 475 &chainKey, 476 &tau, 477 &key, 478 chainKey[:], 479 handshake.presharedKey[:], 480 ) 481 mixHash(&hash, &hash, tau[:]) 482 483 // authenticate transcript 484 485 aead, _ := chacha20poly1305.New(key[:]) 486 _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:]) 487 if err != nil { 488 return false 489 } 490 mixHash(&hash, &hash, msg.Empty[:]) 491 return true 492 }() 493 494 if !ok { 495 return nil 496 } 497 498 // update handshake state 499 500 handshake.mutex.Lock() 501 502 handshake.hash = hash 503 handshake.chainKey = chainKey 504 handshake.remoteIndex = msg.Sender 505 handshake.state = handshakeResponseConsumed 506 507 handshake.mutex.Unlock() 508 509 setZero(hash[:]) 510 setZero(chainKey[:]) 511 512 return lookup.peer 513 } 514 515 /* Derives a new keypair from the current handshake state 516 * 517 */ 518 func (peer *Peer) BeginSymmetricSession() error { 519 device := peer.device 520 handshake := &peer.handshake 521 handshake.mutex.Lock() 522 defer handshake.mutex.Unlock() 523 524 // derive keys 525 526 var isInitiator bool 527 var sendKey [chacha20poly1305.KeySize]byte 528 var recvKey [chacha20poly1305.KeySize]byte 529 530 if handshake.state == handshakeResponseConsumed { 531 KDF2( 532 &sendKey, 533 &recvKey, 534 handshake.chainKey[:], 535 nil, 536 ) 537 isInitiator = true 538 } else if handshake.state == handshakeResponseCreated { 539 KDF2( 540 &recvKey, 541 &sendKey, 542 handshake.chainKey[:], 543 nil, 544 ) 545 isInitiator = false 546 } else { 547 return fmt.Errorf("invalid state for keypair derivation: %v", handshake.state) 548 } 549 550 // zero handshake 551 552 setZero(handshake.chainKey[:]) 553 setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line. 554 setZero(handshake.localEphemeral[:]) 555 peer.handshake.state = handshakeZeroed 556 557 // create AEAD instances 558 559 keypair := new(Keypair) 560 keypair.send, _ = chacha20poly1305.New(sendKey[:]) 561 keypair.receive, _ = chacha20poly1305.New(recvKey[:]) 562 563 setZero(sendKey[:]) 564 setZero(recvKey[:]) 565 566 keypair.created = time.Now() 567 keypair.replayFilter.Reset() 568 keypair.isInitiator = isInitiator 569 keypair.localIndex = peer.handshake.localIndex 570 keypair.remoteIndex = peer.handshake.remoteIndex 571 572 // remap index 573 574 device.indexTable.SwapIndexForKeypair(handshake.localIndex, keypair) 575 handshake.localIndex = 0 576 577 // rotate key pairs 578 579 keypairs := &peer.keypairs 580 keypairs.Lock() 581 defer keypairs.Unlock() 582 583 previous := keypairs.previous 584 next := keypairs.loadNext() 585 current := keypairs.current 586 587 if isInitiator { 588 if next != nil { 589 keypairs.storeNext(nil) 590 keypairs.previous = next 591 device.DeleteKeypair(current) 592 } else { 593 keypairs.previous = current 594 } 595 device.DeleteKeypair(previous) 596 keypairs.current = keypair 597 } else { 598 keypairs.storeNext(keypair) 599 device.DeleteKeypair(next) 600 keypairs.previous = nil 601 device.DeleteKeypair(previous) 602 } 603 604 return nil 605 } 606 607 func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool { 608 keypairs := &peer.keypairs 609 610 if keypairs.loadNext() != receivedKeypair { 611 return false 612 } 613 keypairs.Lock() 614 defer keypairs.Unlock() 615 if keypairs.loadNext() != receivedKeypair { 616 return false 617 } 618 old := keypairs.previous 619 keypairs.previous = keypairs.current 620 peer.device.DeleteKeypair(old) 621 keypairs.current = keypairs.loadNext() 622 keypairs.storeNext(nil) 623 return true 624 }