github.com/noisysockets/noisysockets@v0.21.2-0.20240515114641-7f467e651c90/internal/transport/transport.go (about) 1 // SPDX-License-Identifier: MPL-2.0 2 /* 3 * Copyright (C) 2024 The Noisy Sockets Authors. 4 * 5 * This Source Code Form is subject to the terms of the Mozilla Public 6 * License, v. 2.0. If a copy of the MPL was not distributed with this 7 * file, You can obtain one at http://mozilla.org/MPL/2.0/. 8 * 9 * Portions of this file are based on code originally from wireguard-go, 10 * 11 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 12 * 13 * Permission is hereby granted, free of charge, to any person obtaining a copy of 14 * this software and associated documentation files (the "Software"), to deal in 15 * the Software without restriction, including without limitation the rights to 16 * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 17 * of the Software, and to permit persons to whom the Software is furnished to do 18 * so, subject to the following conditions: 19 * 20 * The above copyright notice and this permission notice shall be included in all 21 * copies or substantial portions of the Software. 22 * 23 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 24 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 25 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 26 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 27 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 28 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 29 * SOFTWARE. 30 */ 31 32 package transport 33 34 import ( 35 "fmt" 36 "log/slog" 37 "runtime" 38 "sync" 39 "sync/atomic" 40 "time" 41 42 "github.com/noisysockets/noisysockets/internal/conn" 43 "github.com/noisysockets/noisysockets/internal/ratelimiter" 44 "github.com/noisysockets/noisysockets/types" 45 ) 46 47 type Transport struct { 48 logger *slog.Logger 49 50 state struct { 51 // state holds the transport's state. It is accessed atomically. 52 // Use the transport.transportState method to read it. 53 // transport.transportState does not acquire the mutex, so it captures only a snapshot. 54 // During state transitions, the state variable is updated before the transport itself. 55 // The state is thus either the current state of the transport or 56 // the intended future state of the transport. 57 // For example, while executing a call to Up, state will be transportStateUp. 58 // There is no guarantee that that intended future state of the transport 59 // will become the actual state; Up can fail. 60 // The transport can also change state multiple times between time of check and time of use. 61 // Unsynchronized uses of state must therefore be advisory/best-effort only. 62 state atomic.Uint32 // actually a transportState, but typed uint32 for convenience 63 // stopping blocks until all inputs to Transport have been closed. 64 stopping sync.WaitGroup 65 // mu protects state changes. 66 sync.Mutex 67 } 68 69 net struct { 70 stopping sync.WaitGroup 71 sync.RWMutex 72 bind conn.Bind // bind interface 73 port uint16 // listening port 74 } 75 76 staticIdentity struct { 77 sync.RWMutex 78 privateKey types.NoisePrivateKey 79 publicKey types.NoisePublicKey 80 } 81 82 peers struct { 83 sync.RWMutex // protects keyMap 84 keyMap map[types.NoisePublicKey]*Peer 85 } 86 87 rate struct { 88 underLoadUntil atomic.Int64 89 limiter ratelimiter.Ratelimiter 90 } 91 92 indexTable IndexTable 93 cookieChecker CookieChecker 94 95 pool struct { 96 inboundElementsContainer *WaitPool 97 outboundElementsContainer *WaitPool 98 messageBuffers *WaitPool 99 inboundElements *WaitPool 100 outboundElements *WaitPool 101 } 102 103 queue struct { 104 encryption *outboundQueue 105 decryption *inboundQueue 106 handshake *handshakeQueue 107 } 108 109 sourceSink SourceSink 110 111 closed chan struct{} 112 } 113 114 // transportState represents the state of a Transport. 115 // There are three states: down, up, closed. 116 // Transitions: 117 // 118 // down -----+ 119 // ↑↓ ↓ 120 // up -> closed 121 type transportState uint32 122 123 const ( 124 transportStateDown transportState = iota 125 transportStateUp 126 transportStateClosed 127 ) 128 129 func (state transportState) String() string { 130 switch state { 131 case transportStateDown: 132 return "down" 133 case transportStateUp: 134 return "up" 135 case transportStateClosed: 136 return "closed" 137 default: 138 return "unknown" 139 } 140 } 141 142 // transportState returns transport.state.state as a transportState 143 // See those docs for how to interpret this value. 144 func (transport *Transport) transportState() transportState { 145 return transportState(transport.state.state.Load()) 146 } 147 148 // isClosed reports whether the transport is closed (or is closing). 149 // See transport.state.state comments for how to interpret this value. 150 func (transport *Transport) isClosed() bool { 151 return transport.transportState() == transportStateClosed 152 } 153 154 // isUp reports whether the transport is up (or is attempting to come up). 155 // See transport.state.state comments for how to interpret this value. 156 func (transport *Transport) isUp() bool { 157 return transport.transportState() == transportStateUp 158 } 159 160 // Must hold transport.peers.Lock() 161 func removePeerLocked(transport *Transport, peer *Peer, key types.NoisePublicKey) { 162 // stop routing and processing of packets 163 peer.Stop() 164 165 // remove from peer map 166 delete(transport.peers.keyMap, key) 167 } 168 169 // changeState attempts to change the transport state to match want. 170 func (transport *Transport) changeState(want transportState) (err error) { 171 transport.state.Lock() 172 defer transport.state.Unlock() 173 old := transport.transportState() 174 if old == transportStateClosed { 175 // once closed, always closed 176 transport.logger.Debug("Interface closed, ignored requested state", slog.String("want", want.String())) 177 return nil 178 } 179 switch want { 180 case old: 181 return nil 182 case transportStateUp: 183 transport.state.state.Store(uint32(transportStateUp)) 184 err = transport.upLocked() 185 if err == nil { 186 break 187 } 188 fallthrough // up failed; bring the transport all the way back down 189 case transportStateDown: 190 transport.state.state.Store(uint32(transportStateDown)) 191 errDown := transport.downLocked() 192 if err == nil { 193 err = errDown 194 } 195 } 196 transport.logger.Debug("Interface state change requested", 197 slog.String("old", old.String()), slog.String("want", want.String()), 198 slog.String("now", transport.transportState().String())) 199 return 200 } 201 202 // upLocked attempts to bring the transport up and reports whether it succeeded. 203 // The caller must hold transport.state.mu and is responsible for updating transport.state.state. 204 func (transport *Transport) upLocked() error { 205 if err := transport.BindUpdate(); err != nil { 206 transport.logger.Error("Unable to update bind", slog.Any("error", err)) 207 return err 208 } 209 210 transport.peers.RLock() 211 for _, peer := range transport.peers.keyMap { 212 peer.Start() 213 if peer.keepAliveInterval.Load() > 0 { 214 if err := peer.SendKeepalive(); err != nil { 215 transport.logger.Error("Failed to send keepalive", 216 slog.String("peer", peer.String()), slog.Any("error", err)) 217 } 218 } 219 } 220 transport.peers.RUnlock() 221 return nil 222 } 223 224 // downLocked attempts to bring the transport down. 225 // The caller must hold transport.state.mu and is responsible for updating transport.state.state. 226 func (transport *Transport) downLocked() error { 227 err := transport.BindClose() 228 if err != nil { 229 transport.logger.Error("Bind close failed", slog.Any("error", err)) 230 } 231 232 transport.peers.RLock() 233 for _, peer := range transport.peers.keyMap { 234 peer.Stop() 235 } 236 transport.peers.RUnlock() 237 return err 238 } 239 240 func (transport *Transport) Up() error { 241 return transport.changeState(transportStateUp) 242 } 243 244 func (transport *Transport) Down() error { 245 return transport.changeState(transportStateDown) 246 } 247 248 func (transport *Transport) IsUnderLoad() bool { 249 // check if currently under load 250 now := time.Now() 251 underLoad := len(transport.queue.handshake.c) >= QueueHandshakeSize/8 252 if underLoad { 253 transport.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime).UnixNano()) 254 return true 255 } 256 // check if recently under load 257 return transport.rate.underLoadUntil.Load() > now.UnixNano() 258 } 259 260 func (transport *Transport) SetPrivateKey(sk types.NoisePrivateKey) { 261 // lock required resources 262 263 transport.staticIdentity.Lock() 264 defer transport.staticIdentity.Unlock() 265 266 if sk.Equals(transport.staticIdentity.privateKey) { 267 return 268 } 269 270 transport.peers.Lock() 271 defer transport.peers.Unlock() 272 273 lockedPeers := make([]*Peer, 0, len(transport.peers.keyMap)) 274 for _, peer := range transport.peers.keyMap { 275 peer.handshake.mutex.RLock() 276 lockedPeers = append(lockedPeers, peer) 277 } 278 279 // remove peers with matching public keys 280 281 publicKey := sk.Public() 282 for key, peer := range transport.peers.keyMap { 283 if peer.handshake.remoteStatic.Equals(publicKey) { 284 peer.handshake.mutex.RUnlock() 285 removePeerLocked(transport, peer, key) 286 peer.handshake.mutex.RLock() 287 } 288 } 289 290 // update key material 291 292 transport.staticIdentity.privateKey = sk 293 transport.staticIdentity.publicKey = publicKey 294 transport.cookieChecker.Init(publicKey) 295 296 // do static-static DH pre-computations 297 298 expiredPeers := make([]*Peer, 0, len(transport.peers.keyMap)) 299 for _, peer := range transport.peers.keyMap { 300 handshake := &peer.handshake 301 handshake.precomputedStaticStatic, _ = sharedSecret(transport.staticIdentity.privateKey, handshake.remoteStatic) 302 expiredPeers = append(expiredPeers, peer) 303 } 304 305 for _, peer := range lockedPeers { 306 peer.handshake.mutex.RUnlock() 307 } 308 for _, peer := range expiredPeers { 309 peer.ExpireCurrentKeypairs() 310 } 311 } 312 313 func NewTransport(logger *slog.Logger, sourceSink SourceSink, bind conn.Bind) *Transport { 314 t := new(Transport) 315 t.state.state.Store(uint32(transportStateDown)) 316 t.closed = make(chan struct{}) 317 t.logger = logger 318 t.net.bind = bind 319 t.sourceSink = sourceSink 320 t.peers.keyMap = make(map[types.NoisePublicKey]*Peer) 321 t.rate.limiter.Init() 322 t.indexTable.Init() 323 324 t.PopulatePools() 325 326 // create queues 327 328 t.queue.handshake = newHandshakeQueue() 329 t.queue.encryption = newOutboundQueue() 330 t.queue.decryption = newInboundQueue() 331 332 // start workers 333 334 cpus := runtime.NumCPU() 335 t.state.stopping.Wait() 336 t.queue.encryption.wg.Add(cpus) // One for each RoutineHandshake 337 for i := 0; i < cpus; i++ { 338 go t.RoutineEncryption(i + 1) 339 go t.RoutineDecryption(i + 1) 340 go t.RoutineHandshake(i + 1) 341 } 342 343 t.state.stopping.Add(1) 344 t.queue.encryption.wg.Add(1) 345 go t.RoutineReadFromSourceSink() 346 347 return t 348 } 349 350 // BatchSize returns the BatchSize for the transport as a whole which is the max of 351 // the bind batch size and the sink batch size. The batch size reported by transport 352 // is the size used to construct memory pools, and is the allowed batch size for 353 // the lifetime of the transport. 354 func (transport *Transport) BatchSize() int { 355 size := transport.net.bind.BatchSize() 356 dSize := transport.sourceSink.BatchSize() 357 if size < dSize { 358 size = dSize 359 } 360 return size 361 } 362 363 func (transport *Transport) Peers() []types.NoisePublicKey { 364 transport.peers.RLock() 365 defer transport.peers.RUnlock() 366 367 keys := make([]types.NoisePublicKey, 0, len(transport.peers.keyMap)) 368 for key := range transport.peers.keyMap { 369 keys = append(keys, key) 370 } 371 return keys 372 } 373 374 func (transport *Transport) LookupPeer(pk types.NoisePublicKey) *Peer { 375 transport.peers.RLock() 376 defer transport.peers.RUnlock() 377 378 return transport.peers.keyMap[pk] 379 } 380 381 func (transport *Transport) RemovePeer(pk types.NoisePublicKey) { 382 transport.peers.Lock() 383 defer transport.peers.Unlock() 384 // stop peer and remove from routing 385 386 peer, ok := transport.peers.keyMap[pk] 387 if ok { 388 removePeerLocked(transport, peer, pk) 389 } 390 } 391 392 func (transport *Transport) RemoveAllPeers() { 393 transport.peers.Lock() 394 defer transport.peers.Unlock() 395 396 for key, peer := range transport.peers.keyMap { 397 removePeerLocked(transport, peer, key) 398 } 399 400 transport.peers.keyMap = make(map[types.NoisePublicKey]*Peer) 401 } 402 403 func (transport *Transport) Close() error { 404 transport.state.Lock() 405 defer transport.state.Unlock() 406 if transport.isClosed() { 407 return nil 408 } 409 transport.state.state.Store(uint32(transportStateClosed)) 410 transport.logger.Debug("Transport closing") 411 412 _ = transport.sourceSink.Close() 413 _ = transport.downLocked() 414 415 // Remove peers before closing queues, 416 // because peers assume that queues are active. 417 transport.RemoveAllPeers() 418 419 // We kept a reference to the encryption and decryption queues, 420 // in case we started any new peers that might write to them. 421 // No new peers are coming; we are done with these queues. 422 transport.queue.encryption.wg.Done() 423 transport.queue.decryption.wg.Done() 424 transport.queue.handshake.wg.Done() 425 transport.state.stopping.Wait() 426 427 if err := transport.rate.limiter.Close(); err != nil { 428 return fmt.Errorf("failed to close rate limiter: %w", err) 429 } 430 431 transport.logger.Debug("Transport closed") 432 close(transport.closed) 433 434 return nil 435 } 436 437 func (transport *Transport) Wait() chan struct{} { 438 return transport.closed 439 } 440 441 func (transport *Transport) SendKeepalivesToPeersWithCurrentKeypair() { 442 if !transport.isUp() { 443 return 444 } 445 446 transport.peers.RLock() 447 for _, peer := range transport.peers.keyMap { 448 peer.keypairs.RLock() 449 sendKeepalive := peer.keypairs.current != nil && !peer.keypairs.current.created.Add(RejectAfterTime).Before(time.Now()) 450 peer.keypairs.RUnlock() 451 if sendKeepalive { 452 if err := peer.SendKeepalive(); err != nil { 453 transport.logger.Error("Failed to send keepalive", 454 slog.String("peer", peer.String()), slog.Any("error", err)) 455 } 456 } 457 } 458 transport.peers.RUnlock() 459 } 460 461 // closeBindLocked closes the transport's net.bind. 462 // The caller must hold the net mutex. 463 func closeBindLocked(transport *Transport) error { 464 var err error 465 netc := &transport.net 466 if netc.bind != nil { 467 err = netc.bind.Close() 468 } 469 netc.stopping.Wait() 470 return err 471 } 472 473 func (transport *Transport) Bind() conn.Bind { 474 transport.net.Lock() 475 defer transport.net.Unlock() 476 return transport.net.bind 477 } 478 479 func (transport *Transport) BindUpdate() error { 480 transport.net.Lock() 481 defer transport.net.Unlock() 482 483 // close existing sockets 484 if err := closeBindLocked(transport); err != nil { 485 return err 486 } 487 488 // open new sockets 489 if !transport.isUp() { 490 return nil 491 } 492 493 // bind to new port 494 var err error 495 var recvFns []conn.ReceiveFunc 496 netc := &transport.net 497 498 recvFns, netc.port, err = netc.bind.Open(netc.port) 499 if err != nil { 500 netc.port = 0 501 return err 502 } 503 504 // start receiving routines 505 transport.net.stopping.Add(len(recvFns)) 506 transport.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to transport.queue.decryption 507 transport.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to transport.queue.handshake 508 batchSize := netc.bind.BatchSize() 509 for _, fn := range recvFns { 510 go transport.RoutineReceiveIncoming(batchSize, fn) 511 } 512 513 transport.logger.Debug("UDP bind has been updated") 514 return nil 515 } 516 517 func (transport *Transport) BindClose() error { 518 transport.net.Lock() 519 err := closeBindLocked(transport) 520 transport.net.Unlock() 521 return err 522 } 523 524 func (transport *Transport) UpdatePort(port uint16) error { 525 transport.net.Lock() 526 transport.net.port = port 527 transport.net.Unlock() 528 529 if err := transport.BindUpdate(); err != nil { 530 return fmt.Errorf("failed to update port: %w", err) 531 } 532 533 return nil 534 }