github.com/noisysockets/noisysockets@v0.21.2-0.20240515114641-7f467e651c90/internal/transport/send.go (about) 1 // SPDX-License-Identifier: MPL-2.0 2 /* 3 * Copyright (C) 2024 The Noisy Sockets Authors. 4 * 5 * This Source Code Form is subject to the terms of the Mozilla Public 6 * License, v. 2.0. If a copy of the MPL was not distributed with this 7 * file, You can obtain one at http://mozilla.org/MPL/2.0/. 8 * 9 * Portions of this file are based on code originally from wireguard-go, 10 * 11 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 12 * 13 * Permission is hereby granted, free of charge, to any person obtaining a copy of 14 * this software and associated documentation files (the "Software"), to deal in 15 * the Software without restriction, including without limitation the rights to 16 * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 17 * of the Software, and to permit persons to whom the Software is furnished to do 18 * so, subject to the following conditions: 19 * 20 * The above copyright notice and this permission notice shall be included in all 21 * copies or substantial portions of the Software. 22 * 23 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 24 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 25 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 26 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 27 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 28 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 29 * SOFTWARE. 30 */ 31 32 package transport 33 34 import ( 35 "bytes" 36 "encoding/binary" 37 "errors" 38 "log/slog" 39 "os" 40 "sync" 41 "time" 42 43 "github.com/noisysockets/noisysockets/internal/conn" 44 "github.com/noisysockets/noisysockets/types" 45 "golang.org/x/crypto/chacha20poly1305" 46 ) 47 48 const DefaultMTU = 1420 49 50 /* Outbound flow 51 * 52 * 1. Source queue 53 * 2. Routing (sequential) 54 * 3. Nonce assignment (sequential) 55 * 4. Encryption (parallel) 56 * 5. Transmission (sequential) 57 * 58 * The functions in this file occur (roughly) in the order in 59 * which the packets are processed. 60 * 61 * Locking, Producers and Consumers 62 * 63 * The order of packets (per peer) must be maintained, 64 * but encryption of packets happen out-of-order: 65 * 66 * The sequential consumers will attempt to take the lock, 67 * workers release lock when they have completed work (encryption) on the packet. 68 * 69 * If the element is inserted into the "encryption queue", 70 * the content is preceded by enough "junk" to contain the transport header 71 * (to allow the construction of transport messages in-place) 72 */ 73 74 type QueueOutboundElement struct { 75 buffer *[MaxMessageSize]byte // slice holding the packet data 76 packet []byte // slice of "buffer" (always!) 77 nonce uint64 // nonce for encryption 78 keypair *Keypair // keypair for encryption 79 peer *Peer // related peer 80 } 81 82 type QueueOutboundElementsContainer struct { 83 sync.Mutex 84 elems []*QueueOutboundElement 85 } 86 87 func (transport *Transport) NewOutboundElement() *QueueOutboundElement { 88 elem := transport.GetOutboundElement() 89 elem.buffer = transport.GetMessageBuffer() 90 elem.nonce = 0 91 // keypair and peer were cleared (if necessary) by clearPointers. 92 return elem 93 } 94 95 // clearPointers clears elem fields that contain pointers. 96 // This makes the garbage collector's life easier and 97 // avoids accidentally keeping other objects around unnecessarily. 98 // It also reduces the possible collateral damage from use-after-free bugs. 99 func (elem *QueueOutboundElement) clearPointers() { 100 elem.buffer = nil 101 elem.packet = nil 102 elem.keypair = nil 103 elem.peer = nil 104 } 105 106 /* Queues a keepalive if no packets are queued for peer 107 */ 108 func (peer *Peer) SendKeepalive() error { 109 if len(peer.queue.staged) == 0 && peer.isRunning.Load() { 110 elem := peer.transport.NewOutboundElement() 111 elemsContainer := peer.transport.GetOutboundElementsContainer() 112 elemsContainer.elems = append(elemsContainer.elems, elem) 113 select { 114 case peer.queue.staged <- elemsContainer: 115 peer.transport.logger.Debug("Sending keepalive packet", slog.String("peer", peer.String())) 116 default: 117 peer.transport.PutMessageBuffer(elem.buffer) 118 peer.transport.PutOutboundElement(elem) 119 peer.transport.PutOutboundElementsContainer(elemsContainer) 120 } 121 } 122 return peer.SendStagedPackets() 123 } 124 125 func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { 126 logger := peer.transport.logger.With(slog.String("peer", peer.String())) 127 128 peer.endpoint.Lock() 129 endpoint := peer.endpoint.val 130 peer.endpoint.Unlock() 131 132 // If we don't have an endpoint, ignore the request. 133 if endpoint == nil { 134 return nil 135 } 136 137 if !isRetry { 138 peer.timers.handshakeAttempts.Store(0) 139 } 140 141 peer.handshake.mutex.RLock() 142 if time.Since(peer.handshake.lastSentHandshake) < RekeyTimeout { 143 peer.handshake.mutex.RUnlock() 144 return nil 145 } 146 peer.handshake.mutex.RUnlock() 147 148 peer.handshake.mutex.Lock() 149 if time.Since(peer.handshake.lastSentHandshake) < RekeyTimeout { 150 peer.handshake.mutex.Unlock() 151 return nil 152 } 153 peer.handshake.lastSentHandshake = time.Now() 154 peer.handshake.mutex.Unlock() 155 156 logger.Debug("Sending handshake initiation") 157 158 msg, err := peer.transport.CreateMessageInitiation(peer) 159 if err != nil { 160 logger.Error("Failed to create initiation message", slog.Any("error", err)) 161 return err 162 } 163 164 var buf [MessageInitiationSize]byte 165 writer := bytes.NewBuffer(buf[:0]) 166 if err := binary.Write(writer, binary.LittleEndian, msg); err != nil { 167 logger.Error("Failed to write initiation message", slog.Any("error", err)) 168 return err 169 } 170 171 packet := writer.Bytes() 172 peer.cookieGenerator.AddMacs(packet) 173 174 peer.timersAnyAuthenticatedPacketTraversal() 175 peer.timersAnyAuthenticatedPacketSent() 176 177 err = peer.SendBuffers([][]byte{packet}) 178 if err != nil { 179 logger.Error("Failed to send handshake initiation", slog.Any("error", err)) 180 } 181 peer.timersHandshakeInitiated() 182 183 return err 184 } 185 186 func (peer *Peer) SendHandshakeResponse() error { 187 logger := peer.transport.logger.With(slog.String("peer", peer.String())) 188 189 peer.handshake.mutex.Lock() 190 peer.handshake.lastSentHandshake = time.Now() 191 peer.handshake.mutex.Unlock() 192 193 logger.Debug("Sending handshake response") 194 195 response, err := peer.transport.CreateMessageResponse(peer) 196 if err != nil { 197 logger.Error("Failed to create handshake response message", slog.Any("error", err)) 198 return err 199 } 200 201 var buf [MessageResponseSize]byte 202 writer := bytes.NewBuffer(buf[:0]) 203 if err := binary.Write(writer, binary.LittleEndian, response); err != nil { 204 logger.Error("Failed to write handshake response message", slog.Any("error", err)) 205 return err 206 } 207 208 packet := writer.Bytes() 209 peer.cookieGenerator.AddMacs(packet) 210 211 err = peer.BeginSymmetricSession() 212 if err != nil { 213 logger.Error("Failed to derive keypair", slog.Any("error", err)) 214 return err 215 } 216 217 peer.timersSessionDerived() 218 peer.timersAnyAuthenticatedPacketTraversal() 219 peer.timersAnyAuthenticatedPacketSent() 220 221 // TODO: allocation could be avoided 222 err = peer.SendBuffers([][]byte{packet}) 223 if err != nil { 224 logger.Error("Failed to send handshake response", slog.Any("error", err)) 225 } 226 return err 227 } 228 229 func (transport *Transport) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) error { 230 logger := transport.logger.With(slog.String("source", initiatingElem.endpoint.DstToString())) 231 232 logger.Debug("Sending cookie response for denied handshake message") 233 234 sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8]) 235 reply, err := transport.cookieChecker.CreateReply(initiatingElem.packet, sender, initiatingElem.endpoint.DstToBytes()) 236 if err != nil { 237 logger.Error("Failed to create cookie reply", slog.Any("error", err)) 238 return err 239 } 240 241 var buf [MessageCookieReplySize]byte 242 writer := bytes.NewBuffer(buf[:0]) 243 if err := binary.Write(writer, binary.LittleEndian, reply); err != nil { 244 logger.Error("Failed to write cookie reply", slog.Any("error", err)) 245 return err 246 } 247 248 // TODO: allocation could be avoided 249 return transport.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint) 250 } 251 252 func (peer *Peer) keepKeyFreshSending() error { 253 keypair := peer.keypairs.Current() 254 if keypair == nil { 255 return nil 256 } 257 258 nonce := keypair.sendNonce.Load() 259 if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) { 260 return peer.SendHandshakeInitiation(false) 261 } 262 263 return nil 264 } 265 266 func (transport *Transport) RoutineReadFromSourceSink() { 267 defer func() { 268 transport.logger.Debug("Routine: Source reader - stopped") 269 transport.state.stopping.Done() 270 transport.queue.encryption.wg.Done() 271 }() 272 273 transport.logger.Debug("Routine: Source reader - started") 274 275 var ( 276 batchSize = transport.BatchSize() 277 readErr error 278 elems = make([]*QueueOutboundElement, batchSize) 279 bufs = make([][]byte, batchSize) 280 peers = make([]types.NoisePublicKey, batchSize) 281 elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize) 282 count int 283 sizes = make([]int, batchSize) 284 offset = MessageTransportHeaderSize 285 ) 286 287 for i := range elems { 288 elems[i] = transport.NewOutboundElement() 289 bufs[i] = elems[i].buffer[:] 290 } 291 292 defer func() { 293 for _, elem := range elems { 294 if elem != nil { 295 transport.PutMessageBuffer(elem.buffer) 296 transport.PutOutboundElement(elem) 297 } 298 } 299 }() 300 301 for { 302 // read packets 303 count, readErr = transport.sourceSink.Read(bufs, sizes, peers, offset) 304 for i := 0; i < count; i++ { 305 if sizes[i] < 1 { 306 continue 307 } 308 309 elem := elems[i] 310 elem.packet = bufs[i][offset : offset+sizes[i]] 311 312 transport.peers.RLock() 313 peer := transport.peers.keyMap[peers[i]] 314 transport.peers.RUnlock() 315 if peer == nil { 316 continue 317 } 318 319 elemsForPeer, ok := elemsByPeer[peer] 320 if !ok { 321 elemsForPeer = transport.GetOutboundElementsContainer() 322 elemsByPeer[peer] = elemsForPeer 323 } 324 elemsForPeer.elems = append(elemsForPeer.elems, elem) 325 elems[i] = transport.NewOutboundElement() 326 bufs[i] = elems[i].buffer[:] 327 } 328 329 for peer, elemsForPeer := range elemsByPeer { 330 if peer.isRunning.Load() { 331 peer.StagePackets(elemsForPeer) 332 if err := peer.SendStagedPackets(); err != nil { 333 transport.logger.Warn("Failed to send staged packets", 334 slog.String("peer", peer.String()), slog.Any("error", err)) 335 continue 336 } 337 } else { 338 for _, elem := range elemsForPeer.elems { 339 transport.PutMessageBuffer(elem.buffer) 340 transport.PutOutboundElement(elem) 341 } 342 transport.PutOutboundElementsContainer(elemsForPeer) 343 } 344 delete(elemsByPeer, peer) 345 } 346 347 if readErr != nil { 348 if !transport.isClosed() { 349 if !errors.Is(readErr, os.ErrClosed) { 350 transport.logger.Error("Failed to read packet from source sink", 351 slog.Any("error", readErr)) 352 } 353 go transport.Close() 354 } 355 return 356 } 357 } 358 } 359 360 func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) { 361 for { 362 select { 363 case peer.queue.staged <- elems: 364 return 365 default: 366 } 367 select { 368 case tooOld := <-peer.queue.staged: 369 for _, elem := range tooOld.elems { 370 peer.transport.PutMessageBuffer(elem.buffer) 371 peer.transport.PutOutboundElement(elem) 372 } 373 peer.transport.PutOutboundElementsContainer(tooOld) 374 default: 375 } 376 } 377 } 378 379 func (peer *Peer) SendStagedPackets() error { 380 top: 381 if len(peer.queue.staged) == 0 || !peer.transport.isUp() { 382 return nil 383 } 384 385 keypair := peer.keypairs.Current() 386 if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime { 387 return peer.SendHandshakeInitiation(false) 388 } 389 390 for { 391 var elemsContainerOOO *QueueOutboundElementsContainer 392 select { 393 case elemsContainer := <-peer.queue.staged: 394 i := 0 395 for _, elem := range elemsContainer.elems { 396 elem.peer = peer 397 elem.nonce = keypair.sendNonce.Add(1) - 1 398 if elem.nonce >= RejectAfterMessages { 399 keypair.sendNonce.Store(RejectAfterMessages) 400 if elemsContainerOOO == nil { 401 elemsContainerOOO = peer.transport.GetOutboundElementsContainer() 402 } 403 elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem) 404 continue 405 } else { 406 elemsContainer.elems[i] = elem 407 i++ 408 } 409 410 elem.keypair = keypair 411 } 412 elemsContainer.Lock() 413 elemsContainer.elems = elemsContainer.elems[:i] 414 415 if elemsContainerOOO != nil { 416 peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans 417 } 418 419 if len(elemsContainer.elems) == 0 { 420 peer.transport.PutOutboundElementsContainer(elemsContainer) 421 goto top 422 } 423 424 // add to parallel and sequential queue 425 if peer.isRunning.Load() { 426 peer.queue.outbound.c <- elemsContainer 427 peer.transport.queue.encryption.c <- elemsContainer 428 } else { 429 for _, elem := range elemsContainer.elems { 430 peer.transport.PutMessageBuffer(elem.buffer) 431 peer.transport.PutOutboundElement(elem) 432 } 433 peer.transport.PutOutboundElementsContainer(elemsContainer) 434 } 435 436 if elemsContainerOOO != nil { 437 goto top 438 } 439 default: 440 return nil 441 } 442 } 443 } 444 445 func (peer *Peer) FlushStagedPackets() { 446 for { 447 select { 448 case elemsContainer := <-peer.queue.staged: 449 for _, elem := range elemsContainer.elems { 450 peer.transport.PutMessageBuffer(elem.buffer) 451 peer.transport.PutOutboundElement(elem) 452 } 453 peer.transport.PutOutboundElementsContainer(elemsContainer) 454 default: 455 return 456 } 457 } 458 } 459 460 func calculatePaddingSize(packetSize, mtu int) int { 461 lastUnit := packetSize 462 if mtu == 0 { 463 return ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)) - lastUnit 464 } 465 if lastUnit > mtu { 466 lastUnit %= mtu 467 } 468 paddedSize := ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)) 469 if paddedSize > mtu { 470 paddedSize = mtu 471 } 472 return paddedSize - lastUnit 473 } 474 475 /* Encrypts the elements in the queue 476 * and marks them for sequential consumption (by releasing the mutex) 477 * 478 * Obs. One instance per core 479 */ 480 func (transport *Transport) RoutineEncryption(id int) { 481 var paddingZeros [PaddingMultiple]byte 482 var nonce [chacha20poly1305.NonceSize]byte 483 484 logger := transport.logger.With(slog.Int("id", id)) 485 486 defer logger.Debug("Routine: encryption worker - stopped") 487 logger.Debug("Routine: encryption worker - started") 488 489 for elemsContainer := range transport.queue.encryption.c { 490 for _, elem := range elemsContainer.elems { 491 // populate header fields 492 header := elem.buffer[:MessageTransportHeaderSize] 493 494 fieldType := header[0:4] 495 fieldReceiver := header[4:8] 496 fieldNonce := header[8:16] 497 498 binary.LittleEndian.PutUint32(fieldType, MessageTransportType) 499 binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex) 500 binary.LittleEndian.PutUint64(fieldNonce, elem.nonce) 501 502 // pad content to multiple of 16 bytes 503 paddingSize := calculatePaddingSize(len(elem.packet), transport.sourceSink.MTU()) 504 elem.packet = append(elem.packet, paddingZeros[:paddingSize]...) 505 506 // encrypt content and release to consumer 507 508 binary.LittleEndian.PutUint64(nonce[4:], elem.nonce) 509 elem.packet = elem.keypair.send.Seal( 510 header, 511 nonce[:], 512 elem.packet, 513 nil, 514 ) 515 } 516 elemsContainer.Unlock() 517 } 518 } 519 520 func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { 521 logger := peer.transport.logger.With(slog.String("peer", peer.String())) 522 523 transport := peer.transport 524 defer func() { 525 defer logger.Debug("Routine: sequential sender - stopped") 526 peer.stopping.Done() 527 }() 528 logger.Debug("Routine: sequential sender - started") 529 530 bufs := make([][]byte, 0, maxBatchSize) 531 532 for elemsContainer := range peer.queue.outbound.c { 533 bufs = bufs[:0] 534 if elemsContainer == nil { 535 return 536 } 537 if !peer.isRunning.Load() { 538 // peer has been stopped; return re-usable elems to the shared pool. 539 // This is an optimization only. It is possible for the peer to be stopped 540 // immediately after this check, in which case, elem will get processed. 541 // The timers and SendBuffers code are resilient to a few stragglers. 542 // TODO: rework peer shutdown order to ensure 543 // that we never accidentally keep timers alive longer than necessary. 544 elemsContainer.Lock() 545 for _, elem := range elemsContainer.elems { 546 transport.PutMessageBuffer(elem.buffer) 547 transport.PutOutboundElement(elem) 548 } 549 continue 550 } 551 dataSent := false 552 elemsContainer.Lock() 553 for _, elem := range elemsContainer.elems { 554 if len(elem.packet) != MessageKeepaliveSize { 555 dataSent = true 556 } 557 bufs = append(bufs, elem.packet) 558 } 559 560 peer.timersAnyAuthenticatedPacketTraversal() 561 peer.timersAnyAuthenticatedPacketSent() 562 563 err := peer.SendBuffers(bufs) 564 if dataSent { 565 peer.timersDataSent() 566 } 567 for _, elem := range elemsContainer.elems { 568 transport.PutMessageBuffer(elem.buffer) 569 transport.PutOutboundElement(elem) 570 } 571 transport.PutOutboundElementsContainer(elemsContainer) 572 if err != nil { 573 var errGSO conn.ErrUDPGSODisabled 574 if errors.As(err, &errGSO) { 575 logger.Warn("Failed to send data packets, retrying", slog.Any("error", err)) 576 err = errGSO.RetryErr 577 } 578 } 579 if err != nil { 580 logger.Error("Failed to send data packets", slog.Any("error", err)) 581 continue 582 } 583 584 if err := peer.keepKeyFreshSending(); err != nil { 585 logger.Error("Failed to keep key fresh", slog.Any("error", err)) 586 } 587 } 588 }