github.com/noisysockets/noisysockets@v0.21.2-0.20240515114641-7f467e651c90/internal/transport/receive.go (about) 1 // SPDX-License-Identifier: MPL-2.0 2 /* 3 * Copyright (C) 2024 The Noisy Sockets Authors. 4 * 5 * This Source Code Form is subject to the terms of the Mozilla Public 6 * License, v. 2.0. If a copy of the MPL was not distributed with this 7 * file, You can obtain one at http://mozilla.org/MPL/2.0/. 8 * 9 * Portions of this file are based on code originally from wireguard-go, 10 * 11 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 12 * 13 * Permission is hereby granted, free of charge, to any person obtaining a copy of 14 * this software and associated documentation files (the "Software"), to deal in 15 * the Software without restriction, including without limitation the rights to 16 * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 17 * of the Software, and to permit persons to whom the Software is furnished to do 18 * so, subject to the following conditions: 19 * 20 * The above copyright notice and this permission notice shall be included in all 21 * copies or substantial portions of the Software. 22 * 23 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 24 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 25 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 26 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 27 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 28 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 29 * SOFTWARE. 30 */ 31 32 package transport 33 34 import ( 35 "bytes" 36 "encoding/binary" 37 "errors" 38 "log/slog" 39 "net" 40 "sync" 41 "time" 42 43 "github.com/noisysockets/noisysockets/internal/conn" 44 "github.com/noisysockets/noisysockets/types" 45 "golang.org/x/crypto/chacha20poly1305" 46 ) 47 48 type QueueHandshakeElement struct { 49 msgType uint32 50 packet []byte 51 endpoint conn.Endpoint 52 buffer *[MaxMessageSize]byte 53 } 54 55 type QueueInboundElement struct { 56 buffer *[MaxMessageSize]byte 57 packet []byte 58 counter uint64 59 keypair *Keypair 60 endpoint conn.Endpoint 61 } 62 63 type QueueInboundElementsContainer struct { 64 sync.Mutex 65 elems []*QueueInboundElement 66 } 67 68 // clearPointers clears elem fields that contain pointers. 69 // This makes the garbage collector's life easier and 70 // avoids accidentally keeping other objects around unnecessarily. 71 // It also reduces the possible collateral damage from use-after-free bugs. 72 func (elem *QueueInboundElement) clearPointers() { 73 elem.buffer = nil 74 elem.packet = nil 75 elem.keypair = nil 76 elem.endpoint = nil 77 } 78 79 /* Called when a new authenticated message has been received 80 * 81 * NOTE: Not thread safe, but called by sequential receiver! 82 */ 83 func (peer *Peer) keepKeyFreshReceiving() error { 84 if peer.timers.sentLastMinuteHandshake.Load() { 85 return nil 86 } 87 88 keypair := peer.keypairs.Current() 89 if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) { 90 peer.timers.sentLastMinuteHandshake.Store(true) 91 if err := peer.SendHandshakeInitiation(false); err != nil { 92 return err 93 } 94 } 95 96 return nil 97 } 98 99 /* Receives incoming datagrams for the transport 100 * 101 * Every time the bind is updated a new routine is started for 102 * IPv4 and IPv6 (separately) 103 */ 104 func (transport *Transport) RoutineReceiveIncoming(maxBatchSize int, recv conn.ReceiveFunc) { 105 recvName := recv.PrettyName() 106 defer func() { 107 transport.logger.Debug("Routine: receive incoming - stopped", slog.String("recvName", recvName)) 108 transport.queue.decryption.wg.Done() 109 transport.queue.handshake.wg.Done() 110 transport.net.stopping.Done() 111 }() 112 113 transport.logger.Debug("Routine: receive incoming - started", slog.String("recvName", recvName)) 114 115 // receive datagrams until conn is closed 116 117 var ( 118 bufsArrs = make([]*[MaxMessageSize]byte, maxBatchSize) 119 bufs = make([][]byte, maxBatchSize) 120 err error 121 sizes = make([]int, maxBatchSize) 122 count int 123 endpoints = make([]conn.Endpoint, maxBatchSize) 124 deathSpiral int 125 elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize) 126 ) 127 128 for i := range bufsArrs { 129 bufsArrs[i] = transport.GetMessageBuffer() 130 bufs[i] = bufsArrs[i][:] 131 } 132 133 defer func() { 134 for i := 0; i < maxBatchSize; i++ { 135 if bufsArrs[i] != nil { 136 transport.PutMessageBuffer(bufsArrs[i]) 137 } 138 } 139 }() 140 141 for { 142 count, err = recv(bufs, sizes, endpoints) 143 if err != nil { 144 if errors.Is(err, net.ErrClosed) { 145 return 146 } 147 transport.logger.Warn("Failed to receive packet", 148 slog.String("recvName", recvName), 149 slog.Any("error", err)) 150 if deathSpiral < 10 { 151 deathSpiral++ 152 time.Sleep(time.Second / 3) 153 continue 154 } 155 return 156 } 157 deathSpiral = 0 158 159 // handle each packet in the batch 160 for i, size := range sizes[:count] { 161 if size < MinMessageSize { 162 continue 163 } 164 165 // check size of packet 166 167 packet := bufsArrs[i][:size] 168 msgType := binary.LittleEndian.Uint32(packet[:4]) 169 170 switch msgType { 171 172 // check if transport 173 174 case MessageTransportType: 175 176 // check size 177 178 if len(packet) < MessageTransportSize { 179 continue 180 } 181 182 // lookup key pair 183 184 receiver := binary.LittleEndian.Uint32( 185 packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter], 186 ) 187 value := transport.indexTable.Lookup(receiver) 188 keypair := value.keypair 189 if keypair == nil { 190 continue 191 } 192 193 // check keypair expiry 194 195 if keypair.created.Add(RejectAfterTime).Before(time.Now()) { 196 continue 197 } 198 199 // create work element 200 peer := value.peer 201 elem := transport.GetInboundElement() 202 elem.packet = packet 203 elem.buffer = bufsArrs[i] 204 elem.keypair = keypair 205 elem.endpoint = endpoints[i] 206 elem.counter = 0 207 208 elemsForPeer, ok := elemsByPeer[peer] 209 if !ok { 210 elemsForPeer = transport.GetInboundElementsContainer() 211 elemsForPeer.Lock() 212 elemsByPeer[peer] = elemsForPeer 213 } 214 elemsForPeer.elems = append(elemsForPeer.elems, elem) 215 bufsArrs[i] = transport.GetMessageBuffer() 216 bufs[i] = bufsArrs[i][:] 217 continue 218 219 // otherwise it is a fixed size & handshake related packet 220 221 case MessageInitiationType: 222 if len(packet) != MessageInitiationSize { 223 continue 224 } 225 226 case MessageResponseType: 227 if len(packet) != MessageResponseSize { 228 continue 229 } 230 231 case MessageCookieReplyType: 232 if len(packet) != MessageCookieReplySize { 233 continue 234 } 235 236 default: 237 transport.logger.Warn("Received message with unknown type", 238 slog.Int("type", int(msgType))) 239 continue 240 } 241 242 select { 243 case transport.queue.handshake.c <- QueueHandshakeElement{ 244 msgType: msgType, 245 buffer: bufsArrs[i], 246 packet: packet, 247 endpoint: endpoints[i], 248 }: 249 bufsArrs[i] = transport.GetMessageBuffer() 250 bufs[i] = bufsArrs[i][:] 251 default: 252 } 253 } 254 for peer, elemsContainer := range elemsByPeer { 255 if peer.isRunning.Load() { 256 peer.queue.inbound.c <- elemsContainer 257 transport.queue.decryption.c <- elemsContainer 258 } else { 259 for _, elem := range elemsContainer.elems { 260 transport.PutMessageBuffer(elem.buffer) 261 transport.PutInboundElement(elem) 262 } 263 transport.PutInboundElementsContainer(elemsContainer) 264 } 265 delete(elemsByPeer, peer) 266 } 267 } 268 } 269 270 func (transport *Transport) RoutineDecryption(id int) { 271 var nonce [chacha20poly1305.NonceSize]byte 272 273 defer transport.logger.Debug("Routine: decryption worker - stopped", slog.Int("id", id)) 274 transport.logger.Debug("Routine: decryption worker - started", slog.Int("id", id)) 275 276 for elemsContainer := range transport.queue.decryption.c { 277 for _, elem := range elemsContainer.elems { 278 // split message into fields 279 counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent] 280 content := elem.packet[MessageTransportOffsetContent:] 281 282 // decrypt and release to consumer 283 var err error 284 elem.counter = binary.LittleEndian.Uint64(counter) 285 // copy counter to nonce 286 binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter) 287 elem.packet, err = elem.keypair.receive.Open( 288 content[:0], 289 nonce[:], 290 content, 291 nil, 292 ) 293 if err != nil { 294 elem.packet = nil 295 } 296 } 297 elemsContainer.Unlock() 298 } 299 } 300 301 // Handles incoming packets related to handshake. 302 func (transport *Transport) RoutineHandshake(id int) { 303 logger := transport.logger.With(slog.Int("id", id)) 304 305 defer func() { 306 logger.Debug("Routine: handshake worker - stopped") 307 transport.queue.encryption.wg.Done() 308 }() 309 logger.Debug("Routine: handshake worker - started") 310 311 for elem := range transport.queue.handshake.c { 312 logger := logger.With(slog.String("from", elem.endpoint.DstToString())) 313 314 // handle cookie fields and ratelimiting 315 316 switch elem.msgType { 317 318 case MessageCookieReplyType: 319 320 // unmarshal packet 321 322 var reply MessageCookieReply 323 reader := bytes.NewReader(elem.packet) 324 err := binary.Read(reader, binary.LittleEndian, &reply) 325 if err != nil { 326 logger.Warn("Failed to decode cookie reply", slog.Any("error", err)) 327 goto skip 328 } 329 330 // lookup peer from index 331 332 entry := transport.indexTable.Lookup(reply.Receiver) 333 334 if entry.peer == nil { 335 goto skip 336 } 337 338 // consume reply 339 340 if peer := entry.peer; peer.isRunning.Load() { 341 logger.Debug("Receiving cookie response") 342 if !peer.cookieGenerator.ConsumeReply(&reply) { 343 logger.Warn("Could not decrypt invalid cookie response") 344 } 345 } 346 347 goto skip 348 349 case MessageInitiationType, MessageResponseType: 350 351 // check mac fields and maybe ratelimit 352 353 if !transport.cookieChecker.CheckMAC1(elem.packet) { 354 logger.Warn("Received packet with invalid mac1") 355 goto skip 356 } 357 358 // endpoints destination address is the source of the datagram 359 360 if transport.IsUnderLoad() { 361 362 // verify MAC2 field 363 364 if !transport.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) { 365 if err := transport.SendHandshakeCookie(&elem); err != nil { 366 logger.Warn("Failed to send handshake cookie", slog.Any("error", err)) 367 } 368 goto skip 369 } 370 371 // check ratelimiter 372 373 if !transport.rate.limiter.Allow(elem.endpoint.DstIP()) { 374 goto skip 375 } 376 } 377 378 default: 379 logger.Warn("Invalid packet ended up in the handshake queue") 380 goto skip 381 } 382 383 // handle handshake initiation/response content 384 385 switch elem.msgType { 386 case MessageInitiationType: 387 388 // unmarshal 389 390 var msg MessageInitiation 391 reader := bytes.NewReader(elem.packet) 392 err := binary.Read(reader, binary.LittleEndian, &msg) 393 if err != nil { 394 logger.Warn("Failed to decode initiation message", slog.Any("error", err)) 395 goto skip 396 } 397 398 // consume initiation 399 400 peer := transport.ConsumeMessageInitiation(&msg) 401 if peer == nil { 402 logger.Warn("Received invalid initiation message") 403 goto skip 404 } 405 406 // update timers 407 408 peer.timersAnyAuthenticatedPacketTraversal() 409 peer.timersAnyAuthenticatedPacketReceived() 410 411 // update endpoint 412 peer.SetEndpoint(elem.endpoint) 413 414 logger.Debug("Received handshake initiation", slog.String("peer", peer.String())) 415 peer.rxBytes.Add(uint64(len(elem.packet))) 416 417 if err := peer.SendHandshakeResponse(); err != nil { 418 logger.Error("Failed to send handshake response", slog.Any("error", err)) 419 goto skip 420 } 421 422 case MessageResponseType: 423 424 // unmarshal 425 426 var msg MessageResponse 427 reader := bytes.NewReader(elem.packet) 428 err := binary.Read(reader, binary.LittleEndian, &msg) 429 if err != nil { 430 logger.Warn("Failed to decode response message", slog.Any("error", err)) 431 goto skip 432 } 433 434 // consume response 435 436 peer := transport.ConsumeMessageResponse(&msg) 437 if peer == nil { 438 logger.Warn("Received invalid response message") 439 goto skip 440 } 441 442 logger := logger.With(slog.String("peer", peer.String())) 443 444 // update endpoint 445 peer.SetEndpoint(elem.endpoint) 446 447 logger.Debug("Received handshake response") 448 peer.rxBytes.Add(uint64(len(elem.packet))) 449 450 // update timers 451 452 peer.timersAnyAuthenticatedPacketTraversal() 453 peer.timersAnyAuthenticatedPacketReceived() 454 455 // derive keypair 456 if err := peer.BeginSymmetricSession(); err != nil { 457 logger.Error("Failed to derive keypair", slog.Any("error", err)) 458 goto skip 459 } 460 461 peer.timersSessionDerived() 462 peer.timersHandshakeComplete() 463 if err := peer.SendKeepalive(); err != nil { 464 logger.Error("Failed to send keepalive", slog.Any("error", err)) 465 goto skip 466 } 467 } 468 skip: 469 transport.PutMessageBuffer(elem.buffer) 470 } 471 } 472 473 func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { 474 t := peer.transport 475 476 logger := t.logger.With(slog.String("peer", peer.String())) 477 478 defer func() { 479 logger.Debug("Routine: sequential receiver - stopped") 480 peer.stopping.Done() 481 }() 482 logger.Debug("Routine: sequential receiver - started") 483 484 bufs := make([][]byte, 0, maxBatchSize) 485 486 peers := make([]types.NoisePublicKey, 0, maxBatchSize) 487 for i := 0; i < maxBatchSize; i++ { 488 peers = append(peers, peer.pk) 489 } 490 491 for elemsContainer := range peer.queue.inbound.c { 492 if elemsContainer == nil { 493 return 494 } 495 elemsContainer.Lock() 496 validTailPacket := -1 497 dataPacketReceived := false 498 rxBytesLen := uint64(0) 499 for i, elem := range elemsContainer.elems { 500 if elem.packet == nil { 501 // decryption failed 502 continue 503 } 504 505 if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) { 506 continue 507 } 508 509 validTailPacket = i 510 if peer.ReceivedWithKeypair(elem.keypair) { 511 peer.SetEndpoint(elem.endpoint) 512 peer.timersHandshakeComplete() 513 if err := peer.SendStagedPackets(); err != nil { 514 logger.Warn("Failed to send staged packets", slog.Any("error", err)) 515 continue 516 } 517 } 518 rxBytesLen += uint64(len(elem.packet) + MinMessageSize) 519 520 if len(elem.packet) == 0 { 521 logger.Debug("Receiving keepalive packet") 522 continue 523 } 524 dataPacketReceived = true 525 526 bufs = append(bufs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)]) 527 } 528 529 peer.rxBytes.Add(rxBytesLen) 530 if validTailPacket >= 0 { 531 peer.SetEndpoint(elemsContainer.elems[validTailPacket].endpoint) 532 if err := peer.keepKeyFreshReceiving(); err != nil { 533 logger.Warn("Failed to keep key fresh", slog.Any("error", err)) 534 continue 535 } 536 peer.timersAnyAuthenticatedPacketTraversal() 537 peer.timersAnyAuthenticatedPacketReceived() 538 } 539 if dataPacketReceived { 540 peer.timersDataReceived() 541 } 542 if len(bufs) > 0 { 543 _, err := t.sourceSink.Write(bufs, peers, MessageTransportOffsetContent) 544 if err != nil && !t.isClosed() { 545 logger.Error("Failed to write packets to source sink", slog.Any("error", err)) 546 } 547 } 548 for _, elem := range elemsContainer.elems { 549 t.PutMessageBuffer(elem.buffer) 550 t.PutInboundElement(elem) 551 } 552 bufs = bufs[:0] 553 t.PutInboundElementsContainer(elemsContainer) 554 } 555 }