github.com/noisysockets/noisysockets@v0.21.2-0.20240515114641-7f467e651c90/internal/transport/peer.go (about) 1 // SPDX-License-Identifier: MPL-2.0 2 /* 3 * Copyright (C) 2024 The Noisy Sockets Authors. 4 * 5 * This Source Code Form is subject to the terms of the Mozilla Public 6 * License, v. 2.0. If a copy of the MPL was not distributed with this 7 * file, You can obtain one at http://mozilla.org/MPL/2.0/. 8 * 9 * Portions of this file are based on code originally from wireguard-go, 10 * 11 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 12 * 13 * Permission is hereby granted, free of charge, to any person obtaining a copy of 14 * this software and associated documentation files (the "Software"), to deal in 15 * the Software without restriction, including without limitation the rights to 16 * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 17 * of the Software, and to permit persons to whom the Software is furnished to do 18 * so, subject to the following conditions: 19 * 20 * The above copyright notice and this permission notice shall be included in all 21 * copies or substantial portions of the Software. 22 * 23 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 24 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 25 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 26 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 27 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 28 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 29 * SOFTWARE. 30 */ 31 32 package transport 33 34 import ( 35 "errors" 36 "fmt" 37 "log/slog" 38 "sync" 39 "sync/atomic" 40 "time" 41 42 "github.com/noisysockets/noisysockets/internal/conn" 43 "github.com/noisysockets/noisysockets/types" 44 ) 45 46 type Peer struct { 47 pk types.NoisePublicKey 48 isRunning atomic.Bool 49 keypairs Keypairs 50 handshake Handshake 51 transport *Transport 52 stopping sync.WaitGroup // routines pending stop 53 txBytes atomic.Uint64 // bytes send to peer (endpoint) 54 rxBytes atomic.Uint64 // bytes received from peer 55 lastHandshakeNano atomic.Int64 // nano seconds since epoch 56 57 endpoint struct { 58 sync.Mutex 59 val conn.Endpoint 60 } 61 62 timers struct { 63 retransmitHandshake *Timer 64 sendKeepalive *Timer 65 newHandshake *Timer 66 zeroKeyMaterial *Timer 67 keepAlive *Timer 68 handshakeAttempts atomic.Uint32 69 needAnotherKeepalive atomic.Bool 70 sentLastMinuteHandshake atomic.Bool 71 } 72 73 state struct { 74 sync.Mutex // protects against concurrent Start/Stop 75 } 76 77 queue struct { 78 staged chan *QueueOutboundElementsContainer // staged packets before a handshake is available 79 outbound *autodrainingOutboundQueue // sequential ordering of udp transmission 80 inbound *autodrainingInboundQueue // sequential ordering of sink writing 81 } 82 83 cookieGenerator CookieGenerator 84 keepAliveInterval atomic.Uint32 85 } 86 87 func (transport *Transport) NewPeer(pk types.NoisePublicKey) (*Peer, error) { 88 if transport.isClosed() { 89 return nil, errors.New("transport closed") 90 } 91 92 // lock resources 93 transport.staticIdentity.RLock() 94 defer transport.staticIdentity.RUnlock() 95 96 transport.peers.Lock() 97 defer transport.peers.Unlock() 98 99 // check if over limit 100 if len(transport.peers.keyMap) >= MaxPeers { 101 return nil, errors.New("too many peers") 102 } 103 104 // create peer 105 peer := new(Peer) 106 107 peer.pk = pk 108 peer.cookieGenerator.Init(pk) 109 peer.transport = transport 110 peer.queue.outbound = newAutodrainingOutboundQueue(transport) 111 peer.queue.inbound = newAutodrainingInboundQueue(transport) 112 peer.queue.staged = make(chan *QueueOutboundElementsContainer, QueueStagedSize) 113 114 // map public key 115 _, ok := transport.peers.keyMap[pk] 116 if ok { 117 return nil, errors.New("adding existing peer") 118 } 119 120 // pre-compute DH 121 handshake := &peer.handshake 122 handshake.mutex.Lock() 123 handshake.precomputedStaticStatic, _ = sharedSecret(transport.staticIdentity.privateKey, pk) 124 handshake.remoteStatic = pk 125 handshake.mutex.Unlock() 126 127 // reset endpoint 128 peer.endpoint.Lock() 129 peer.endpoint.val = nil 130 peer.endpoint.Unlock() 131 132 // init timers 133 peer.timersInit() 134 135 // add 136 transport.peers.keyMap[pk] = peer 137 138 return peer, nil 139 } 140 141 func (peer *Peer) SendBuffers(buffers [][]byte) error { 142 peer.transport.net.RLock() 143 defer peer.transport.net.RUnlock() 144 145 if peer.transport.isClosed() { 146 return nil 147 } 148 149 peer.endpoint.Lock() 150 endpoint := peer.endpoint.val 151 if endpoint == nil { 152 peer.endpoint.Unlock() 153 return errors.New("no known endpoint for peer") 154 } 155 peer.endpoint.Unlock() 156 157 err := peer.transport.net.bind.Send(buffers, endpoint) 158 if err == nil { 159 var totalLen uint64 160 for _, b := range buffers { 161 totalLen += uint64(len(b)) 162 } 163 peer.txBytes.Add(totalLen) 164 } 165 return err 166 } 167 168 func (peer *Peer) String() string { 169 return fmt.Sprintf("peer(%s)", peer.handshake.remoteStatic.DisplayString()) 170 } 171 172 func (peer *Peer) Start() { 173 // should never start a peer on a closed transport. 174 if peer.transport.isClosed() { 175 return 176 } 177 178 // prevent simultaneous start/stop operations 179 peer.state.Lock() 180 defer peer.state.Unlock() 181 182 if peer.isRunning.Load() { 183 return 184 } 185 186 transport := peer.transport 187 transport.logger.Debug("Starting", slog.String("peer", peer.String())) 188 189 // reset routine state 190 peer.stopping.Wait() 191 peer.stopping.Add(2) 192 193 peer.handshake.mutex.Lock() 194 peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second)) 195 peer.handshake.mutex.Unlock() 196 197 peer.transport.queue.encryption.wg.Add(1) // keep encryption queue open for our writes 198 199 peer.timersStart() 200 201 transport.flushInboundQueue(peer.queue.inbound) 202 transport.flushOutboundQueue(peer.queue.outbound) 203 204 // Use the transport batch size, not the bind batch size, as the transport size is 205 // the size of the batch pools. 206 batchSize := peer.transport.BatchSize() 207 go peer.RoutineSequentialSender(batchSize) 208 go peer.RoutineSequentialReceiver(batchSize) 209 210 peer.isRunning.Store(true) 211 } 212 213 func (peer *Peer) ZeroAndFlushAll() { 214 transport := peer.transport 215 216 // clear key pairs 217 218 keypairs := &peer.keypairs 219 keypairs.Lock() 220 transport.DeleteKeypair(keypairs.previous) 221 transport.DeleteKeypair(keypairs.current) 222 transport.DeleteKeypair(keypairs.next.Load()) 223 keypairs.previous = nil 224 keypairs.current = nil 225 keypairs.next.Store(nil) 226 keypairs.Unlock() 227 228 // clear handshake state 229 230 handshake := &peer.handshake 231 handshake.mutex.Lock() 232 transport.indexTable.Delete(handshake.localIndex) 233 handshake.Clear() 234 handshake.mutex.Unlock() 235 236 peer.FlushStagedPackets() 237 } 238 239 func (peer *Peer) ExpireCurrentKeypairs() { 240 handshake := &peer.handshake 241 handshake.mutex.Lock() 242 peer.transport.indexTable.Delete(handshake.localIndex) 243 handshake.Clear() 244 peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second)) 245 handshake.mutex.Unlock() 246 247 keypairs := &peer.keypairs 248 keypairs.Lock() 249 if keypairs.current != nil { 250 keypairs.current.sendNonce.Store(RejectAfterMessages) 251 } 252 if next := keypairs.next.Load(); next != nil { 253 next.sendNonce.Store(RejectAfterMessages) 254 } 255 keypairs.Unlock() 256 } 257 258 func (peer *Peer) Stop() { 259 peer.state.Lock() 260 defer peer.state.Unlock() 261 262 if !peer.isRunning.Swap(false) { 263 return 264 } 265 266 peer.transport.logger.Debug("Stopping", slog.String("peer", peer.String())) 267 268 peer.timersStop() 269 // Signal that RoutineSequentialSender and RoutineSequentialReceiver should exit. 270 peer.queue.inbound.c <- nil 271 peer.queue.outbound.c <- nil 272 peer.stopping.Wait() 273 peer.transport.queue.encryption.wg.Done() // no more writes to encryption queue from us 274 275 peer.ZeroAndFlushAll() 276 } 277 278 func (peer *Peer) GetEndpoint() conn.Endpoint { 279 peer.endpoint.Lock() 280 defer peer.endpoint.Unlock() 281 return peer.endpoint.val 282 } 283 284 func (peer *Peer) SetEndpoint(endpoint conn.Endpoint) { 285 peer.endpoint.Lock() 286 defer peer.endpoint.Unlock() 287 peer.endpoint.val = endpoint 288 } 289 290 func (peer *Peer) SetKeepAliveInterval(interval time.Duration) { 291 peer.keepAliveInterval.Store(uint32(interval.Seconds())) 292 }