// SPDX-License-Identifier: MPL-2.0
/*
 * Copyright (C) 2024 The Noisy Sockets Authors.
 *
 * This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/.
 *
 * Portions of this file are based on code originally from wireguard-go,
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy of
 * this software and associated documentation files (the "Software"), to deal in
 * the Software without restriction, including without limitation the rights to
 * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
 * of the Software, and to permit persons to whom the Software is furnished to do
 * so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
30 */ 31 32 package transport 33 34 import ( 35 "errors" 36 "fmt" 37 "log/slog" 38 "sync" 39 "time" 40 41 "github.com/noisysockets/noisysockets/internal/tai64n" 42 "github.com/noisysockets/noisysockets/types" 43 "golang.org/x/crypto/blake2s" 44 "golang.org/x/crypto/chacha20poly1305" 45 46 //nolint:staticcheck 47 "golang.org/x/crypto/poly1305" 48 ) 49 50 type handshakeState int 51 52 const ( 53 handshakeZeroed = handshakeState(iota) 54 handshakeInitiationCreated 55 handshakeInitiationConsumed 56 handshakeResponseCreated 57 handshakeResponseConsumed 58 ) 59 60 func (hs handshakeState) String() string { 61 switch hs { 62 case handshakeZeroed: 63 return "handshakeZeroed" 64 case handshakeInitiationCreated: 65 return "handshakeInitiationCreated" 66 case handshakeInitiationConsumed: 67 return "handshakeInitiationConsumed" 68 case handshakeResponseCreated: 69 return "handshakeResponseCreated" 70 case handshakeResponseConsumed: 71 return "handshakeResponseConsumed" 72 default: 73 return fmt.Sprintf("Handshake(UNKNOWN:%d)", int(hs)) 74 } 75 } 76 77 const ( 78 NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s" 79 NoiseIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com" 80 NoiseLabelMAC1 = "mac1----" 81 NoiseLabelCookie = "cookie--" 82 ) 83 84 const ( 85 MessageInitiationType = 1 86 MessageResponseType = 2 87 MessageCookieReplyType = 3 88 MessageTransportType = 4 89 ) 90 91 const ( 92 MessageInitiationSize = 148 // size of handshake initiation message 93 MessageResponseSize = 92 // size of response message 94 MessageCookieReplySize = 64 // size of cookie reply message 95 MessageTransportHeaderSize = 16 // size of data preceding content in transport message 96 MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport 97 MessageKeepaliveSize = MessageTransportSize // size of keepalive 98 MessageHandshakeSize = MessageInitiationSize // size of largest handshake related message 99 ) 100 101 const ( 102 MessageTransportOffsetReceiver = 
4 103 MessageTransportOffsetCounter = 8 104 MessageTransportOffsetContent = 16 105 ) 106 107 /* Type is an 8-bit field, followed by 3 nul bytes, 108 * by marshalling the messages in little-endian byteorder 109 * we can treat these as a 32-bit unsigned int (for now) 110 * 111 */ 112 113 type MessageInitiation struct { 114 Type uint32 115 Sender uint32 116 Ephemeral types.NoisePublicKey 117 Static [types.NoisePublicKeySize + poly1305.TagSize]byte 118 Timestamp [tai64n.TimestampSize + poly1305.TagSize]byte 119 MAC1 [blake2s.Size128]byte 120 MAC2 [blake2s.Size128]byte 121 } 122 123 type MessageResponse struct { 124 Type uint32 125 Sender uint32 126 Receiver uint32 127 Ephemeral types.NoisePublicKey 128 Empty [poly1305.TagSize]byte 129 MAC1 [blake2s.Size128]byte 130 MAC2 [blake2s.Size128]byte 131 } 132 133 type MessageTransport struct { 134 Type uint32 135 Receiver uint32 136 Counter uint64 137 Content []byte 138 } 139 140 type MessageCookieReply struct { 141 Type uint32 142 Receiver uint32 143 Nonce [chacha20poly1305.NonceSizeX]byte 144 Cookie [blake2s.Size128 + poly1305.TagSize]byte 145 } 146 147 type Handshake struct { 148 state handshakeState 149 mutex sync.RWMutex 150 hash [blake2s.Size]byte // hash value 151 chainKey [blake2s.Size]byte // chain key 152 presharedKey types.NoisePresharedKey // psk 153 localEphemeral types.NoisePrivateKey // ephemeral secret key 154 localIndex uint32 // used to clear hash-table 155 remoteIndex uint32 // index for sending 156 remoteStatic types.NoisePublicKey // long term key 157 remoteEphemeral types.NoisePublicKey // ephemeral public key 158 precomputedStaticStatic [types.NoisePublicKeySize]byte // precomputed shared secret 159 lastTimestamp tai64n.Timestamp 160 lastInitiationConsumption time.Time 161 lastSentHandshake time.Time 162 } 163 164 var ( 165 InitialChainKey [blake2s.Size]byte 166 InitialHash [blake2s.Size]byte 167 ZeroNonce [chacha20poly1305.NonceSize]byte 168 ) 169 170 func mixKey(dst, c *[blake2s.Size]byte, data []byte) 
{ 171 KDF1(dst, c[:], data) 172 } 173 174 func mixHash(dst, h *[blake2s.Size]byte, data []byte) { 175 hash, _ := blake2s.New256(nil) 176 hash.Write(h[:]) 177 hash.Write(data) 178 hash.Sum(dst[:0]) 179 hash.Reset() 180 } 181 182 func (h *Handshake) Clear() { 183 setZero(h.localEphemeral[:]) 184 setZero(h.remoteEphemeral[:]) 185 setZero(h.chainKey[:]) 186 setZero(h.hash[:]) 187 h.localIndex = 0 188 h.state = handshakeZeroed 189 } 190 191 func (h *Handshake) mixHash(data []byte) { 192 mixHash(&h.hash, &h.hash, data) 193 } 194 195 func (h *Handshake) mixKey(data []byte) { 196 mixKey(&h.chainKey, &h.chainKey, data) 197 } 198 199 /* Do basic precomputations 200 */ 201 func init() { 202 InitialChainKey = blake2s.Sum256([]byte(NoiseConstruction)) 203 mixHash(&InitialHash, &InitialChainKey, []byte(NoiseIdentifier)) 204 } 205 206 func (transport *Transport) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { 207 transport.staticIdentity.RLock() 208 defer transport.staticIdentity.RUnlock() 209 210 handshake := &peer.handshake 211 handshake.mutex.Lock() 212 defer handshake.mutex.Unlock() 213 214 // create ephemeral key 215 var err error 216 handshake.hash = InitialHash 217 handshake.chainKey = InitialChainKey 218 handshake.localEphemeral, err = types.NewPrivateKey() 219 if err != nil { 220 return nil, err 221 } 222 223 handshake.mixHash(handshake.remoteStatic[:]) 224 225 msg := MessageInitiation{ 226 Type: MessageInitiationType, 227 Ephemeral: handshake.localEphemeral.Public(), 228 } 229 230 handshake.mixKey(msg.Ephemeral[:]) 231 handshake.mixHash(msg.Ephemeral[:]) 232 233 // encrypt static key 234 ss, err := sharedSecret(handshake.localEphemeral, handshake.remoteStatic) 235 if err != nil { 236 return nil, err 237 } 238 var key [chacha20poly1305.KeySize]byte 239 KDF2( 240 &handshake.chainKey, 241 &key, 242 handshake.chainKey[:], 243 ss[:], 244 ) 245 aead, _ := chacha20poly1305.New(key[:]) 246 aead.Seal(msg.Static[:0], ZeroNonce[:], 
transport.staticIdentity.publicKey[:], handshake.hash[:]) 247 handshake.mixHash(msg.Static[:]) 248 249 // encrypt timestamp 250 if isZero(handshake.precomputedStaticStatic[:]) { 251 return nil, errInvalidPublicKey 252 } 253 KDF2( 254 &handshake.chainKey, 255 &key, 256 handshake.chainKey[:], 257 handshake.precomputedStaticStatic[:], 258 ) 259 timestamp := tai64n.Now() 260 aead, _ = chacha20poly1305.New(key[:]) 261 aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:]) 262 263 // assign index 264 transport.indexTable.Delete(handshake.localIndex) 265 msg.Sender, err = transport.indexTable.NewIndexForHandshake(peer, handshake) 266 if err != nil { 267 return nil, err 268 } 269 handshake.localIndex = msg.Sender 270 271 handshake.mixHash(msg.Timestamp[:]) 272 handshake.state = handshakeInitiationCreated 273 return &msg, nil 274 } 275 276 func (transport *Transport) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { 277 var ( 278 hash [blake2s.Size]byte 279 chainKey [blake2s.Size]byte 280 ) 281 282 if msg.Type != MessageInitiationType { 283 return nil 284 } 285 286 transport.staticIdentity.RLock() 287 defer transport.staticIdentity.RUnlock() 288 289 mixHash(&hash, &InitialHash, transport.staticIdentity.publicKey[:]) 290 mixHash(&hash, &hash, msg.Ephemeral[:]) 291 mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:]) 292 293 // decrypt static key 294 var peerPK types.NoisePublicKey 295 var key [chacha20poly1305.KeySize]byte 296 ss, err := sharedSecret(transport.staticIdentity.privateKey, msg.Ephemeral) 297 if err != nil { 298 return nil 299 } 300 KDF2(&chainKey, &key, chainKey[:], ss[:]) 301 aead, _ := chacha20poly1305.New(key[:]) 302 _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:]) 303 if err != nil { 304 return nil 305 } 306 mixHash(&hash, &hash, msg.Static[:]) 307 308 // lookup peer 309 310 peer := transport.LookupPeer(peerPK) 311 if peer == nil || !peer.isRunning.Load() { 312 return nil 313 } 314 315 handshake := 
&peer.handshake 316 317 // verify identity 318 319 var timestamp tai64n.Timestamp 320 321 handshake.mutex.RLock() 322 323 if isZero(handshake.precomputedStaticStatic[:]) { 324 handshake.mutex.RUnlock() 325 return nil 326 } 327 KDF2( 328 &chainKey, 329 &key, 330 chainKey[:], 331 handshake.precomputedStaticStatic[:], 332 ) 333 aead, _ = chacha20poly1305.New(key[:]) 334 _, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:]) 335 if err != nil { 336 handshake.mutex.RUnlock() 337 return nil 338 } 339 mixHash(&hash, &hash, msg.Timestamp[:]) 340 341 // protect against replay & flood 342 343 replay := !timestamp.After(handshake.lastTimestamp) 344 flood := time.Since(handshake.lastInitiationConsumption) <= HandshakeInitationRate 345 handshake.mutex.RUnlock() 346 if replay { 347 transport.logger.Debug("ConsumeMessageInitiation: handshake replay", 348 slog.String("peer", peer.String())) 349 return nil 350 } 351 if flood { 352 transport.logger.Debug("ConsumeMessageInitiation: handshake flood", 353 slog.String("peer", peer.String())) 354 return nil 355 } 356 357 // update handshake state 358 359 handshake.mutex.Lock() 360 361 handshake.hash = hash 362 handshake.chainKey = chainKey 363 handshake.remoteIndex = msg.Sender 364 handshake.remoteEphemeral = msg.Ephemeral 365 if timestamp.After(handshake.lastTimestamp) { 366 handshake.lastTimestamp = timestamp 367 } 368 now := time.Now() 369 if now.After(handshake.lastInitiationConsumption) { 370 handshake.lastInitiationConsumption = now 371 } 372 handshake.state = handshakeInitiationConsumed 373 374 handshake.mutex.Unlock() 375 376 setZero(hash[:]) 377 setZero(chainKey[:]) 378 379 return peer 380 } 381 382 func (transport *Transport) CreateMessageResponse(peer *Peer) (*MessageResponse, error) { 383 handshake := &peer.handshake 384 handshake.mutex.Lock() 385 defer handshake.mutex.Unlock() 386 387 if handshake.state != handshakeInitiationConsumed { 388 return nil, errors.New("handshake initiation must be consumed 
first") 389 } 390 391 // assign index 392 393 var err error 394 transport.indexTable.Delete(handshake.localIndex) 395 handshake.localIndex, err = transport.indexTable.NewIndexForHandshake(peer, handshake) 396 if err != nil { 397 return nil, err 398 } 399 400 var msg MessageResponse 401 msg.Type = MessageResponseType 402 msg.Sender = handshake.localIndex 403 msg.Receiver = handshake.remoteIndex 404 405 // create ephemeral key 406 407 handshake.localEphemeral, err = types.NewPrivateKey() 408 if err != nil { 409 return nil, err 410 } 411 msg.Ephemeral = handshake.localEphemeral.Public() 412 handshake.mixHash(msg.Ephemeral[:]) 413 handshake.mixKey(msg.Ephemeral[:]) 414 415 ss, err := sharedSecret(handshake.localEphemeral, handshake.remoteEphemeral) 416 if err != nil { 417 return nil, err 418 } 419 handshake.mixKey(ss[:]) 420 ss, err = sharedSecret(handshake.localEphemeral, handshake.remoteStatic) 421 if err != nil { 422 return nil, err 423 } 424 handshake.mixKey(ss[:]) 425 426 // add preshared key 427 428 var tau [blake2s.Size]byte 429 var key [chacha20poly1305.KeySize]byte 430 431 KDF3( 432 &handshake.chainKey, 433 &tau, 434 &key, 435 handshake.chainKey[:], 436 handshake.presharedKey[:], 437 ) 438 439 handshake.mixHash(tau[:]) 440 441 aead, _ := chacha20poly1305.New(key[:]) 442 aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:]) 443 handshake.mixHash(msg.Empty[:]) 444 445 handshake.state = handshakeResponseCreated 446 447 return &msg, nil 448 } 449 450 func (transport *Transport) ConsumeMessageResponse(msg *MessageResponse) *Peer { 451 if msg.Type != MessageResponseType { 452 transport.logger.Debug("ConsumeMessageResponse: invalid message type", 453 slog.Int("type", int(msg.Type))) 454 return nil 455 } 456 457 // lookup handshake by receiver 458 459 lookup := transport.indexTable.Lookup(msg.Receiver) 460 handshake := lookup.handshake 461 if handshake == nil { 462 transport.logger.Debug("ConsumeMessageResponse: no handshake found for receiver", 463 
slog.Int("receiver", int(msg.Receiver))) 464 return nil 465 } 466 467 var ( 468 hash [blake2s.Size]byte 469 chainKey [blake2s.Size]byte 470 ) 471 472 ok := func() bool { 473 // lock handshake state 474 475 handshake.mutex.RLock() 476 defer handshake.mutex.RUnlock() 477 478 if handshake.state != handshakeInitiationCreated { 479 transport.logger.Debug("ConsumeMessageResponse: invalid state", 480 slog.String("peer", lookup.peer.String()), 481 slog.String("handshakeState", handshake.state.String())) 482 return false 483 } 484 485 // lock private key for reading 486 487 transport.staticIdentity.RLock() 488 defer transport.staticIdentity.RUnlock() 489 490 // finish 3-way DH 491 492 mixHash(&hash, &handshake.hash, msg.Ephemeral[:]) 493 mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:]) 494 495 ss, err := sharedSecret(handshake.localEphemeral, msg.Ephemeral) 496 if err != nil { 497 transport.logger.Debug("ConsumeMessageResponse: failed to compute shared secret", 498 slog.String("peer", lookup.peer.String())) 499 return false 500 } 501 mixKey(&chainKey, &chainKey, ss[:]) 502 setZero(ss[:]) 503 504 ss, err = sharedSecret(transport.staticIdentity.privateKey, msg.Ephemeral) 505 if err != nil { 506 transport.logger.Debug("ConsumeMessageResponse: failed to compute shared secret", 507 slog.String("peer", lookup.peer.String())) 508 return false 509 } 510 mixKey(&chainKey, &chainKey, ss[:]) 511 setZero(ss[:]) 512 513 // add preshared key (psk) 514 515 var tau [blake2s.Size]byte 516 var key [chacha20poly1305.KeySize]byte 517 KDF3( 518 &chainKey, 519 &tau, 520 &key, 521 chainKey[:], 522 handshake.presharedKey[:], 523 ) 524 mixHash(&hash, &hash, tau[:]) 525 526 // authenticate transcript 527 528 aead, _ := chacha20poly1305.New(key[:]) 529 _, err = aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:]) 530 if err != nil { 531 transport.logger.Debug("ConsumeMessageResponse: failed to authenticate transcript", 532 slog.String("peer", lookup.peer.String())) 533 return false 534 } 535 
mixHash(&hash, &hash, msg.Empty[:]) 536 return true 537 }() 538 539 if !ok { 540 return nil 541 } 542 543 // update handshake state 544 545 handshake.mutex.Lock() 546 547 handshake.hash = hash 548 handshake.chainKey = chainKey 549 handshake.remoteIndex = msg.Sender 550 handshake.state = handshakeResponseConsumed 551 552 handshake.mutex.Unlock() 553 554 setZero(hash[:]) 555 setZero(chainKey[:]) 556 557 return lookup.peer 558 } 559 560 /* Derives a new keypair from the current handshake state 561 * 562 */ 563 func (peer *Peer) BeginSymmetricSession() error { 564 transport := peer.transport 565 handshake := &peer.handshake 566 handshake.mutex.Lock() 567 defer handshake.mutex.Unlock() 568 569 // derive keys 570 571 var isInitiator bool 572 var sendKey [chacha20poly1305.KeySize]byte 573 var recvKey [chacha20poly1305.KeySize]byte 574 575 if handshake.state == handshakeResponseConsumed { 576 KDF2( 577 &sendKey, 578 &recvKey, 579 handshake.chainKey[:], 580 nil, 581 ) 582 isInitiator = true 583 } else if handshake.state == handshakeResponseCreated { 584 KDF2( 585 &recvKey, 586 &sendKey, 587 handshake.chainKey[:], 588 nil, 589 ) 590 isInitiator = false 591 } else { 592 return fmt.Errorf("invalid state for keypair derivation: %v", handshake.state) 593 } 594 595 // zero handshake 596 597 setZero(handshake.chainKey[:]) 598 setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line. 
599 setZero(handshake.localEphemeral[:]) 600 peer.handshake.state = handshakeZeroed 601 602 // create AEAD instances 603 604 keypair := new(Keypair) 605 keypair.send, _ = chacha20poly1305.New(sendKey[:]) 606 keypair.receive, _ = chacha20poly1305.New(recvKey[:]) 607 608 setZero(sendKey[:]) 609 setZero(recvKey[:]) 610 611 keypair.created = time.Now() 612 keypair.replayFilter.Reset() 613 keypair.isInitiator = isInitiator 614 keypair.localIndex = peer.handshake.localIndex 615 keypair.remoteIndex = peer.handshake.remoteIndex 616 617 // remap index 618 619 transport.indexTable.SwapIndexForKeypair(handshake.localIndex, keypair) 620 handshake.localIndex = 0 621 622 // rotate key pairs 623 624 keypairs := &peer.keypairs 625 keypairs.Lock() 626 defer keypairs.Unlock() 627 628 previous := keypairs.previous 629 next := keypairs.next.Load() 630 current := keypairs.current 631 632 if isInitiator { 633 if next != nil { 634 keypairs.next.Store(nil) 635 keypairs.previous = next 636 transport.DeleteKeypair(current) 637 } else { 638 keypairs.previous = current 639 } 640 transport.DeleteKeypair(previous) 641 keypairs.current = keypair 642 } else { 643 keypairs.next.Store(keypair) 644 transport.DeleteKeypair(next) 645 keypairs.previous = nil 646 transport.DeleteKeypair(previous) 647 } 648 649 return nil 650 } 651 652 func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool { 653 keypairs := &peer.keypairs 654 655 if keypairs.next.Load() != receivedKeypair { 656 return false 657 } 658 keypairs.Lock() 659 defer keypairs.Unlock() 660 if keypairs.next.Load() != receivedKeypair { 661 return false 662 } 663 old := keypairs.previous 664 keypairs.previous = keypairs.current 665 peer.transport.DeleteKeypair(old) 666 keypairs.current = keypairs.next.Load() 667 keypairs.next.Store(nil) 668 return true 669 }