github.com/tailscale/wireguard-go@v0.0.20201119-0.20210522003738-46b531feb08a/device/noise-protocol.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package device 7 8 import ( 9 "errors" 10 "fmt" 11 "sync" 12 "time" 13 14 "golang.org/x/crypto/blake2s" 15 "golang.org/x/crypto/chacha20poly1305" 16 "golang.org/x/crypto/poly1305" 17 18 "github.com/tailscale/wireguard-go/tai64n" 19 ) 20 21 type handshakeState int 22 23 const ( 24 handshakeZeroed = handshakeState(iota) 25 handshakeInitiationCreated 26 handshakeInitiationConsumed 27 handshakeResponseCreated 28 handshakeResponseConsumed 29 ) 30 31 func (hs handshakeState) String() string { 32 switch hs { 33 case handshakeZeroed: 34 return "handshakeZeroed" 35 case handshakeInitiationCreated: 36 return "handshakeInitiationCreated" 37 case handshakeInitiationConsumed: 38 return "handshakeInitiationConsumed" 39 case handshakeResponseCreated: 40 return "handshakeResponseCreated" 41 case handshakeResponseConsumed: 42 return "handshakeResponseConsumed" 43 default: 44 return fmt.Sprintf("Handshake(UNKNOWN:%d)", int(hs)) 45 } 46 } 47 48 const ( 49 NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s" 50 WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com" 51 WGLabelMAC1 = "mac1----" 52 WGLabelCookie = "cookie--" 53 ) 54 55 const ( 56 MessageInitiationType = 1 57 MessageResponseType = 2 58 MessageCookieReplyType = 3 59 MessageTransportType = 4 60 ) 61 62 const ( 63 MessageInitiationSize = 148 // size of handshake initiation message 64 MessageResponseSize = 92 // size of response message 65 MessageCookieReplySize = 64 // size of cookie reply message 66 MessageTransportHeaderSize = 16 // size of data preceding content in transport message 67 MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport 68 MessageKeepaliveSize = MessageTransportSize // size of keepalive 69 MessageHandshakeSize = MessageInitiationSize // size of largest handshake related message 70 ) 71 72 const ( 73 MessageTransportOffsetReceiver = 4 74 MessageTransportOffsetCounter = 8 75 MessageTransportOffsetContent = 16 76 ) 77 78 /* Type is an 8-bit field, followed by 3 nul bytes, 79 * by marshalling the messages in little-endian byteorder 80 * we can treat these as a 32-bit unsigned int (for now) 81 * 82 */ 83 84 type MessageInitiation struct { 85 Type uint32 86 Sender uint32 87 Ephemeral NoisePublicKey 88 Static [NoisePublicKeySize + poly1305.TagSize]byte 89 Timestamp [tai64n.TimestampSize + poly1305.TagSize]byte 90 MAC1 [blake2s.Size128]byte 91 MAC2 [blake2s.Size128]byte 92 } 93 94 type MessageResponse struct { 95 Type uint32 96 Sender uint32 97 Receiver uint32 98 Ephemeral NoisePublicKey 99 Empty [poly1305.TagSize]byte 100 MAC1 [blake2s.Size128]byte 101 MAC2 [blake2s.Size128]byte 102 } 103 104 type MessageTransport struct { 105 Type uint32 106 Receiver uint32 107 Counter uint64 108 Content []byte 109 } 110 111 type MessageCookieReply struct { 112 Type uint32 113 Receiver uint32 114 Nonce [chacha20poly1305.NonceSizeX]byte 115 Cookie [blake2s.Size128 + poly1305.TagSize]byte 116 } 117 118 type Handshake struct { 119 state handshakeState 120 mutex sync.RWMutex 121 hash [blake2s.Size]byte // hash value 122 chainKey [blake2s.Size]byte // chain key 123 presharedKey NoisePresharedKey // psk 124 localEphemeral NoisePrivateKey // ephemeral secret key 125 localIndex uint32 // used to clear hash-table 126 remoteIndex uint32 // index for sending 127 remoteStatic NoisePublicKey // long term key 128 remoteEphemeral NoisePublicKey // ephemeral public key 129 precomputedStaticStatic [NoisePublicKeySize]byte // precomputed shared secret 130 lastTimestamp tai64n.Timestamp 131 lastInitiationConsumption time.Time 132 lastSentHandshake time.Time 133 } 134 135 var ( 136 InitialChainKey [blake2s.Size]byte 137 InitialHash [blake2s.Size]byte 138 ZeroNonce [chacha20poly1305.NonceSize]byte 139 ) 140 141 func mixKey(dst *[blake2s.Size]byte, c *[blake2s.Size]byte, data []byte) { 142 KDF1(dst, c[:], data) 143 } 144 145 func mixHash(dst *[blake2s.Size]byte, h *[blake2s.Size]byte, data []byte) { 146 hash, _ := blake2s.New256(nil) 147 hash.Write(h[:]) 148 hash.Write(data) 149 hash.Sum(dst[:0]) 150 hash.Reset() 151 } 152 153 func (h *Handshake) Clear() { 154 setZero(h.localEphemeral[:]) 155 setZero(h.remoteEphemeral[:]) 156 setZero(h.chainKey[:]) 157 setZero(h.hash[:]) 158 h.localIndex = 0 159 h.state = handshakeZeroed 160 } 161 162 func (h *Handshake) mixHash(data []byte) { 163 mixHash(&h.hash, &h.hash, data) 164 } 165 166 func (h *Handshake) mixKey(data []byte) { 167 mixKey(&h.chainKey, &h.chainKey, data) 168 } 169 170 /* Do basic precomputations 171 */ 172 func init() { 173 InitialChainKey = blake2s.Sum256([]byte(NoiseConstruction)) 174 mixHash(&InitialHash, &InitialChainKey, []byte(WGIdentifier)) 175 } 176 177 func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { 178 var errZeroECDHResult = errors.New("ECDH returned all zeros") 179 180 device.staticIdentity.RLock() 181 defer device.staticIdentity.RUnlock() 182 183 handshake := &peer.handshake 184 handshake.mutex.Lock() 185 defer handshake.mutex.Unlock() 186 187 // create ephemeral key 188 var err error 189 handshake.hash = InitialHash 190 handshake.chainKey = InitialChainKey 191 handshake.localEphemeral, err = newPrivateKey() 192 if err != nil { 193 return nil, err 194 } 195 196 handshake.mixHash(handshake.remoteStatic[:]) 197 198 msg := MessageInitiation{ 199 Type: MessageInitiationType, 200 Ephemeral: handshake.localEphemeral.publicKey(), 201 } 202 203 handshake.mixKey(msg.Ephemeral[:]) 204 handshake.mixHash(msg.Ephemeral[:]) 205 206 // encrypt static key 207 ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic) 208 if isZero(ss[:]) { 209 return nil, errZeroECDHResult 210 } 211 var key [chacha20poly1305.KeySize]byte 212 KDF2( 213 &handshake.chainKey, 214 &key, 215 handshake.chainKey[:], 216 ss[:], 217 ) 218 aead, _ := chacha20poly1305.New(key[:]) 219 aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:]) 220 handshake.mixHash(msg.Static[:]) 221 222 // encrypt timestamp 223 if isZero(handshake.precomputedStaticStatic[:]) { 224 return nil, errZeroECDHResult 225 } 226 KDF2( 227 &handshake.chainKey, 228 &key, 229 handshake.chainKey[:], 230 handshake.precomputedStaticStatic[:], 231 ) 232 timestamp := tai64n.Now() 233 aead, _ = chacha20poly1305.New(key[:]) 234 aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:]) 235 236 // assign index 237 device.indexTable.Delete(handshake.localIndex) 238 msg.Sender, err = device.indexTable.NewIndexForHandshake(peer, handshake) 239 if err != nil { 240 return nil, err 241 } 242 handshake.localIndex = msg.Sender 243 244 handshake.mixHash(msg.Timestamp[:]) 245 handshake.state = handshakeInitiationCreated 246 return &msg, nil 247 } 248 249 func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { 250 var ( 251 hash [blake2s.Size]byte 252 chainKey [blake2s.Size]byte 253 ) 254 255 if msg.Type != MessageInitiationType { 256 return nil 257 } 258 259 device.staticIdentity.RLock() 260 defer device.staticIdentity.RUnlock() 261 262 mixHash(&hash, &InitialHash, device.staticIdentity.publicKey[:]) 263 mixHash(&hash, &hash, msg.Ephemeral[:]) 264 mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:]) 265 266 // decrypt static key 267 var err error 268 var peerPK NoisePublicKey 269 var key [chacha20poly1305.KeySize]byte 270 ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) 271 if isZero(ss[:]) { 272 return nil 273 } 274 KDF2(&chainKey, &key, chainKey[:], ss[:]) 275 aead, _ := chacha20poly1305.New(key[:]) 276 _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:]) 277 if err != nil { 278 return nil 279 } 280 mixHash(&hash, &hash, msg.Static[:]) 281 282 // lookup peer 283 284 peer := device.LookupPeer(peerPK) 285 if peer == nil { 286 return nil 287 } 288 289 handshake := &peer.handshake 290 291 // verify identity 292 293 var timestamp tai64n.Timestamp 294 295 handshake.mutex.RLock() 296 297 if isZero(handshake.precomputedStaticStatic[:]) { 298 handshake.mutex.RUnlock() 299 return nil 300 } 301 KDF2( 302 &chainKey, 303 &key, 304 chainKey[:], 305 handshake.precomputedStaticStatic[:], 306 ) 307 aead, _ = chacha20poly1305.New(key[:]) 308 _, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:]) 309 if err != nil { 310 handshake.mutex.RUnlock() 311 return nil 312 } 313 mixHash(&hash, &hash, msg.Timestamp[:]) 314 315 // protect against replay & flood 316 317 replay := !timestamp.After(handshake.lastTimestamp) 318 flood := time.Since(handshake.lastInitiationConsumption) <= HandshakeInitationRate 319 handshake.mutex.RUnlock() 320 if replay { 321 device.log.Verbosef("%v - ConsumeMessageInitiation: handshake replay @ %v", peer, timestamp) 322 return nil 323 } 324 if flood { 325 device.log.Verbosef("%v - ConsumeMessageInitiation: handshake flood", peer) 326 return nil 327 } 328 329 // update handshake state 330 331 handshake.mutex.Lock() 332 333 handshake.hash = hash 334 handshake.chainKey = chainKey 335 handshake.remoteIndex = msg.Sender 336 handshake.remoteEphemeral = msg.Ephemeral 337 if timestamp.After(handshake.lastTimestamp) { 338 handshake.lastTimestamp = timestamp 339 } 340 now := time.Now() 341 if now.After(handshake.lastInitiationConsumption) { 342 handshake.lastInitiationConsumption = now 343 } 344 handshake.state = handshakeInitiationConsumed 345 346 handshake.mutex.Unlock() 347 348 setZero(hash[:]) 349 setZero(chainKey[:]) 350 351 return peer 352 } 353 354 func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error) { 355 handshake := &peer.handshake 356 handshake.mutex.Lock() 357 defer handshake.mutex.Unlock() 358 359 if handshake.state != handshakeInitiationConsumed { 360 return nil, errors.New("handshake initiation must be consumed first") 361 } 362 363 // assign index 364 365 var err error 366 device.indexTable.Delete(handshake.localIndex) 367 handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake) 368 if err != nil { 369 return nil, err 370 } 371 372 var msg MessageResponse 373 msg.Type = MessageResponseType 374 msg.Sender = handshake.localIndex 375 msg.Receiver = handshake.remoteIndex 376 377 // create ephemeral key 378 379 handshake.localEphemeral, err = newPrivateKey() 380 if err != nil { 381 return nil, err 382 } 383 msg.Ephemeral = handshake.localEphemeral.publicKey() 384 handshake.mixHash(msg.Ephemeral[:]) 385 handshake.mixKey(msg.Ephemeral[:]) 386 387 func() { 388 ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral) 389 handshake.mixKey(ss[:]) 390 ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic) 391 handshake.mixKey(ss[:]) 392 }() 393 394 // add preshared key 395 396 var tau [blake2s.Size]byte 397 var key [chacha20poly1305.KeySize]byte 398 399 KDF3( 400 &handshake.chainKey, 401 &tau, 402 &key, 403 handshake.chainKey[:], 404 handshake.presharedKey[:], 405 ) 406 407 handshake.mixHash(tau[:]) 408 409 func() { 410 aead, _ := chacha20poly1305.New(key[:]) 411 aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:]) 412 handshake.mixHash(msg.Empty[:]) 413 }() 414 415 handshake.state = handshakeResponseCreated 416 417 return &msg, nil 418 } 419 420 func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { 421 if msg.Type != MessageResponseType { 422 return nil 423 } 424 425 // lookup handshake by receiver 426 427 lookup := device.indexTable.Lookup(msg.Receiver) 428 handshake := lookup.handshake 429 if handshake == nil { 430 return nil 431 } 432 433 var ( 434 hash [blake2s.Size]byte 435 chainKey [blake2s.Size]byte 436 ) 437 438 ok := func() bool { 439 440 // lock handshake state 441 442 handshake.mutex.RLock() 443 defer handshake.mutex.RUnlock() 444 445 if handshake.state != handshakeInitiationCreated { 446 return false 447 } 448 449 // lock private key for reading 450 451 device.staticIdentity.RLock() 452 defer device.staticIdentity.RUnlock() 453 454 // finish 3-way DH 455 456 mixHash(&hash, &handshake.hash, msg.Ephemeral[:]) 457 mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:]) 458 459 func() { 460 ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral) 461 mixKey(&chainKey, &chainKey, ss[:]) 462 setZero(ss[:]) 463 }() 464 465 func() { 466 ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) 467 mixKey(&chainKey, &chainKey, ss[:]) 468 setZero(ss[:]) 469 }() 470 471 // add preshared key (psk) 472 473 var tau [blake2s.Size]byte 474 var key [chacha20poly1305.KeySize]byte 475 KDF3( 476 &chainKey, 477 &tau, 478 &key, 479 chainKey[:], 480 handshake.presharedKey[:], 481 ) 482 mixHash(&hash, &hash, tau[:]) 483 484 // authenticate transcript 485 486 aead, _ := chacha20poly1305.New(key[:]) 487 _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:]) 488 if err != nil { 489 return false 490 } 491 mixHash(&hash, &hash, msg.Empty[:]) 492 return true 493 }() 494 495 if !ok { 496 return nil 497 } 498 499 // update handshake state 500 501 handshake.mutex.Lock() 502 503 handshake.hash = hash 504 handshake.chainKey = chainKey 505 handshake.remoteIndex = msg.Sender 506 handshake.state = handshakeResponseConsumed 507 508 handshake.mutex.Unlock() 509 510 setZero(hash[:]) 511 setZero(chainKey[:]) 512 513 return lookup.peer 514 } 515 516 /* Derives a new keypair from the current handshake state 517 * 518 */ 519 func (peer *Peer) BeginSymmetricSession() error { 520 device := peer.device 521 handshake := &peer.handshake 522 handshake.mutex.Lock() 523 defer handshake.mutex.Unlock() 524 525 // derive keys 526 527 var isInitiator bool 528 var sendKey [chacha20poly1305.KeySize]byte 529 var recvKey [chacha20poly1305.KeySize]byte 530 531 if handshake.state == handshakeResponseConsumed { 532 KDF2( 533 &sendKey, 534 &recvKey, 535 handshake.chainKey[:], 536 nil, 537 ) 538 isInitiator = true 539 } else if handshake.state == handshakeResponseCreated { 540 KDF2( 541 &recvKey, 542 &sendKey, 543 handshake.chainKey[:], 544 nil, 545 ) 546 isInitiator = false 547 } else { 548 return fmt.Errorf("invalid state for keypair derivation: %v", handshake.state) 549 } 550 551 // zero handshake 552 553 setZero(handshake.chainKey[:]) 554 setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line. 555 setZero(handshake.localEphemeral[:]) 556 peer.handshake.state = handshakeZeroed 557 558 // create AEAD instances 559 560 keypair := new(Keypair) 561 keypair.send, _ = chacha20poly1305.New(sendKey[:]) 562 keypair.receive, _ = chacha20poly1305.New(recvKey[:]) 563 564 setZero(sendKey[:]) 565 setZero(recvKey[:]) 566 567 keypair.created = time.Now() 568 keypair.replayFilter.Reset() 569 keypair.isInitiator = isInitiator 570 keypair.localIndex = peer.handshake.localIndex 571 keypair.remoteIndex = peer.handshake.remoteIndex 572 573 // remap index 574 575 device.indexTable.SwapIndexForKeypair(handshake.localIndex, keypair) 576 handshake.localIndex = 0 577 578 // rotate key pairs 579 580 keypairs := &peer.keypairs 581 keypairs.Lock() 582 defer keypairs.Unlock() 583 584 previous := keypairs.previous 585 next := keypairs.loadNext() 586 current := keypairs.current 587 588 if isInitiator { 589 if next != nil { 590 keypairs.storeNext(nil) 591 keypairs.previous = next 592 device.DeleteKeypair(current) 593 } else { 594 keypairs.previous = current 595 } 596 device.DeleteKeypair(previous) 597 keypairs.current = keypair 598 } else { 599 keypairs.storeNext(keypair) 600 device.DeleteKeypair(next) 601 keypairs.previous = nil 602 device.DeleteKeypair(previous) 603 } 604 605 return nil 606 } 607 608 func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool { 609 keypairs := &peer.keypairs 610 611 if keypairs.loadNext() != receivedKeypair { 612 return false 613 } 614 keypairs.Lock() 615 defer keypairs.Unlock() 616 if keypairs.loadNext() != receivedKeypair { 617 return false 618 } 619 old := keypairs.previous 620 keypairs.previous = keypairs.current 621 peer.device.DeleteKeypair(old) 622 keypairs.current = keypairs.loadNext() 623 keypairs.storeNext(nil) 624 return true 625 }