// Source: github.com/noisysockets/noisysockets@v0.21.2-0.20240515114641-7f467e651c90/source_sink.go
// SPDX-License-Identifier: MPL-2.0
/*
 * Copyright (C) 2024 The Noisy Sockets Authors.
 *
 * This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/.
 *
 * Portions of this file are based on code originally from wireguard-go,
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy of
 * this software and associated documentation files (the "Software"), to deal in
 * the Software without restriction, including without limitation the rights to
 * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
 * of the Software, and to permit persons to whom the Software is furnished to do
 * so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
30 */ 31 32 package noisysockets 33 34 import ( 35 "context" 36 "fmt" 37 "hash/fnv" 38 "log/slog" 39 "net" 40 "net/netip" 41 "syscall" 42 43 "github.com/noisysockets/netstack/pkg/buffer" 44 "github.com/noisysockets/netstack/pkg/tcpip" 45 "github.com/noisysockets/netstack/pkg/tcpip/header" 46 "github.com/noisysockets/netstack/pkg/tcpip/link/channel" 47 "github.com/noisysockets/netstack/pkg/tcpip/network/ipv4" 48 "github.com/noisysockets/netstack/pkg/tcpip/network/ipv6" 49 "github.com/noisysockets/netstack/pkg/tcpip/stack" 50 "github.com/noisysockets/noisysockets/internal/conn" 51 "github.com/noisysockets/noisysockets/internal/transport" 52 "github.com/noisysockets/noisysockets/networkutil" 53 "github.com/noisysockets/noisysockets/types" 54 ) 55 56 const ( 57 queueSize = 1024 58 ) 59 60 var ( 61 _ transport.SourceSink = (*sourceSink)(nil) 62 ) 63 64 type sourceSink struct { 65 logger *slog.Logger 66 debugLogging bool 67 rt *routingTable 68 stack *stack.Stack 69 ep *channel.Endpoint 70 notifyHandle *channel.NotificationHandle 71 incoming chan *stack.PacketBuffer 72 interfaceAddrs []netip.Addr 73 } 74 75 func newSourceSink(logger *slog.Logger, rt *routingTable, s *stack.Stack, 76 interfaceAddrs []netip.Addr) (*sourceSink, error) { 77 ss := &sourceSink{ 78 logger: logger, 79 debugLogging: logger.Enabled(context.Background(), slog.LevelDebug), 80 rt: rt, 81 stack: s, 82 ep: channel.New(queueSize, uint32(transport.DefaultMTU), ""), 83 incoming: make(chan *stack.PacketBuffer), 84 interfaceAddrs: interfaceAddrs, 85 } 86 87 ss.notifyHandle = ss.ep.AddNotify(ss) 88 89 if err := ss.stack.CreateNIC(1, ss.ep); err != nil { 90 return nil, fmt.Errorf("could not create NIC: %v", err) 91 } 92 93 // Add default routes. 
94 var routes []tcpip.Route 95 if networkutil.HasIPv4(interfaceAddrs) { 96 routes = append(routes, tcpip.Route{ 97 NIC: 1, 98 Destination: header.IPv4EmptySubnet, 99 }) 100 } 101 if networkutil.HasIPv6(interfaceAddrs) { 102 routes = append(routes, tcpip.Route{ 103 NIC: 1, 104 Destination: header.IPv6EmptySubnet, 105 }) 106 } 107 ss.stack.SetRouteTable(routes) 108 109 // Assign local addresses to the nic. 110 for _, addr := range interfaceAddrs { 111 var protoNumber tcpip.NetworkProtocolNumber 112 if addr.Is4() { 113 protoNumber = ipv4.ProtocolNumber 114 } else if addr.Is6() { 115 protoNumber = ipv6.ProtocolNumber 116 } 117 118 protoAddr := tcpip.ProtocolAddress{ 119 Protocol: protoNumber, 120 AddressWithPrefix: tcpip.AddrFromSlice(addr.AsSlice()).WithPrefix(), 121 } 122 123 logger.Debug("Adding local address", slog.String("address", addr.String())) 124 125 if err := ss.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}); err != nil { 126 return nil, fmt.Errorf("could not add address: %v", err) 127 } 128 } 129 130 return ss, nil 131 } 132 133 func (ss *sourceSink) Close() error { 134 ss.ep.RemoveNotify(ss.notifyHandle) 135 ss.ep.Close() 136 137 ss.stack.RemoveNIC(1) 138 139 // Drain the incoming channel before closing. 
140 ss.drain() 141 142 close(ss.incoming) 143 144 return nil 145 } 146 147 func (ss *sourceSink) Read(bufs [][]byte, sizes []int, destinations []types.NoisePublicKey, offset int) (int, error) { 148 packetFn := func(idx int, pkt *stack.PacketBuffer) error { 149 defer pkt.DecRef() 150 151 if ss.debugLogging { 152 ss.logger.Debug("Processing netstack packet", 153 slog.Uint64("packetHash", hashPacketMetadata(pkt))) 154 } 155 156 // Extract the destination IP address from the packet 157 var peerAddr netip.Addr 158 switch pkt.NetworkProtocolNumber { 159 case header.IPv4ProtocolNumber: 160 hdr := header.IPv4(pkt.NetworkHeader().Slice()) 161 if !hdr.IsValid(pkt.Size()) { 162 return fmt.Errorf("invalid IPv4 header") 163 } 164 165 peerAddr = netip.AddrFrom4(hdr.DestinationAddress().As4()) 166 case header.IPv6ProtocolNumber: 167 hdr := header.IPv6(pkt.NetworkHeader().Slice()) 168 if !hdr.IsValid(pkt.Size()) { 169 return fmt.Errorf("invalid IPv6 header") 170 } 171 172 peerAddr = netip.AddrFrom16(hdr.DestinationAddress().As16()) 173 default: 174 return fmt.Errorf("unknown network protocol: %w", syscall.EAFNOSUPPORT) 175 } 176 177 logger := ss.logger.With(slog.String("address", peerAddr.String())) 178 179 dstPeer, ok := ss.rt.destination(peerAddr) 180 if !ok { 181 return fmt.Errorf("unknown destination address") 182 } 183 destinations[idx] = dstPeer.PublicKey() 184 185 if ss.debugLogging { 186 logger.Debug("Sending packet to peer", 187 slog.String("destination", destinations[idx].DisplayString())) 188 } 189 190 view := pkt.ToView() 191 n, err := view.Read(bufs[idx][offset:]) 192 view.Release() 193 if err != nil { 194 return fmt.Errorf("could not read packet: %w", err) 195 } 196 197 sizes[idx] = n 198 199 return nil 200 } 201 202 // Always block until we have at least one packet. 
203 var count int 204 pkt, ok := <-ss.incoming 205 if !ok { 206 return 0, net.ErrClosed 207 } 208 209 if err := packetFn(count, pkt); err != nil { 210 ss.logger.Warn("Failed to process packet", slog.Any("error", err)) 211 return count, err 212 } 213 214 count++ 215 216 for count < len(bufs) { 217 select { 218 case pkt, ok := <-ss.incoming: 219 if !ok { 220 return count, net.ErrClosed 221 } 222 223 if err := packetFn(count, pkt); err != nil { 224 ss.logger.Warn("Failed to process packet", slog.Any("error", err)) 225 return count, err 226 } 227 228 count++ 229 default: 230 return count, nil 231 } 232 } 233 234 return count, nil 235 } 236 237 func (ss *sourceSink) Write(bufs [][]byte, sources []types.NoisePublicKey, offset int) (int, error) { 238 for i, buf := range bufs { 239 if len(buf) <= offset { 240 continue 241 } 242 243 if ss.debugLogging { 244 ss.logger.Debug("Received packet from peer", slog.String("source", sources[i].DisplayString())) 245 } 246 247 // Validate the source address (to prevent spoofing). 
248 protocolNumber, err := ss.validateSourceAddress(buf[offset:], sources[i]) 249 if err != nil { 250 return i, err 251 } 252 253 pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(buf[offset:])}) 254 255 if ss.debugLogging { 256 ss.logger.Debug("Injecting inbound packet into netstack") 257 } 258 259 ss.ep.InjectInbound(protocolNumber, pkt) 260 } 261 262 return len(bufs), nil 263 } 264 265 func (ss *sourceSink) MTU() int { 266 return int(ss.ep.MTU()) 267 } 268 269 func (ss *sourceSink) BatchSize() int { 270 return conn.IdealBatchSize 271 } 272 273 func (ss *sourceSink) WriteNotify() { 274 pkt := ss.ep.Read() 275 if pkt == nil { 276 return 277 } 278 279 if ss.debugLogging { 280 ss.logger.Debug("Received outbound packet from netstack", 281 slog.Uint64("packetHash", hashPacketMetadata(pkt))) 282 } 283 284 ss.incoming <- pkt 285 } 286 287 func (ss *sourceSink) drain() { 288 for { 289 select { 290 case pkt, ok := <-ss.incoming: 291 if !ok { 292 return 293 } 294 295 pkt.DecRef() 296 default: 297 return 298 } 299 } 300 } 301 302 func (ss *sourceSink) validateSourceAddress(buf []byte, source types.NoisePublicKey) (tcpip.NetworkProtocolNumber, error) { 303 var protocolNumber tcpip.NetworkProtocolNumber 304 switch header.IPVersion(buf) { 305 case header.IPv4Version: 306 protocolNumber = header.IPv4ProtocolNumber 307 case header.IPv6Version: 308 protocolNumber = header.IPv6ProtocolNumber 309 default: 310 return 0, fmt.Errorf("unknown IP version: %w", syscall.EAFNOSUPPORT) 311 } 312 313 var addr netip.Addr 314 switch protocolNumber { 315 case header.IPv4ProtocolNumber: 316 hdr := header.IPv4(buf) 317 if !hdr.IsValid(len(buf)) { 318 return protocolNumber, fmt.Errorf("invalid IPv4 header") 319 } 320 321 addr = netip.AddrFrom4(hdr.SourceAddress().As4()) 322 case header.IPv6ProtocolNumber: 323 hdr := header.IPv6(buf) 324 if !hdr.IsValid(len(buf)) { 325 return protocolNumber, fmt.Errorf("invalid IPv6 header") 326 } 327 328 addr = 
netip.AddrFrom16(hdr.SourceAddress().As16()) 329 default: 330 return protocolNumber, fmt.Errorf("unknown network protocol: %w", syscall.EAFNOSUPPORT) 331 } 332 333 expectedSrcPeer, ok := ss.rt.destination(addr) 334 if !ok { 335 return protocolNumber, fmt.Errorf("unknown source address") 336 } 337 338 if !expectedSrcPeer.PublicKey().Equals(source) { 339 return protocolNumber, fmt.Errorf("invalid source address for peer") 340 } 341 342 return protocolNumber, nil 343 } 344 345 func hashPacketMetadata(pkt *stack.PacketBuffer) uint64 { 346 h := fnv.New64a() 347 _, _ = h.Write(pkt.NetworkHeader().Slice()) 348 _, _ = h.Write(pkt.TransportHeader().Slice()) 349 return h.Sum64() 350 }