github.com/noisysockets/noisysockets@v0.21.2-0.20240515114641-7f467e651c90/source_sink.go (about)

     1  // SPDX-License-Identifier: MPL-2.0
     2  /*
     3   * Copyright (C) 2024 The Noisy Sockets Authors.
     4   *
     5   * This Source Code Form is subject to the terms of the Mozilla Public
     6   * License, v. 2.0. If a copy of the MPL was not distributed with this
     7   * file, You can obtain one at http://mozilla.org/MPL/2.0/.
     8   *
     9   * Portions of this file are based on code originally from wireguard-go,
    10   *
    11   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
    12   *
    13   * Permission is hereby granted, free of charge, to any person obtaining a copy of
    14   * this software and associated documentation files (the "Software"), to deal in
    15   * the Software without restriction, including without limitation the rights to
    16   * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
    17   * of the Software, and to permit persons to whom the Software is furnished to do
    18   * so, subject to the following conditions:
    19   *
    20   * The above copyright notice and this permission notice shall be included in all
    21   * copies or substantial portions of the Software.
    22   *
    23   * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    24   * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    25   * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    26   * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    27   * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    28   * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    29   * SOFTWARE.
    30   */
    31  
    32  package noisysockets
    33  
    34  import (
    35  	"context"
    36  	"fmt"
    37  	"hash/fnv"
    38  	"log/slog"
    39  	"net"
    40  	"net/netip"
    41  	"syscall"
    42  
    43  	"github.com/noisysockets/netstack/pkg/buffer"
    44  	"github.com/noisysockets/netstack/pkg/tcpip"
    45  	"github.com/noisysockets/netstack/pkg/tcpip/header"
    46  	"github.com/noisysockets/netstack/pkg/tcpip/link/channel"
    47  	"github.com/noisysockets/netstack/pkg/tcpip/network/ipv4"
    48  	"github.com/noisysockets/netstack/pkg/tcpip/network/ipv6"
    49  	"github.com/noisysockets/netstack/pkg/tcpip/stack"
    50  	"github.com/noisysockets/noisysockets/internal/conn"
    51  	"github.com/noisysockets/noisysockets/internal/transport"
    52  	"github.com/noisysockets/noisysockets/networkutil"
    53  	"github.com/noisysockets/noisysockets/types"
    54  )
    55  
    56  const (
    57  	queueSize = 1024
    58  )
    59  
    60  var (
    61  	_ transport.SourceSink = (*sourceSink)(nil)
    62  )
    63  
    64  type sourceSink struct {
    65  	logger         *slog.Logger
    66  	debugLogging   bool
    67  	rt             *routingTable
    68  	stack          *stack.Stack
    69  	ep             *channel.Endpoint
    70  	notifyHandle   *channel.NotificationHandle
    71  	incoming       chan *stack.PacketBuffer
    72  	interfaceAddrs []netip.Addr
    73  }
    74  
    75  func newSourceSink(logger *slog.Logger, rt *routingTable, s *stack.Stack,
    76  	interfaceAddrs []netip.Addr) (*sourceSink, error) {
    77  	ss := &sourceSink{
    78  		logger:         logger,
    79  		debugLogging:   logger.Enabled(context.Background(), slog.LevelDebug),
    80  		rt:             rt,
    81  		stack:          s,
    82  		ep:             channel.New(queueSize, uint32(transport.DefaultMTU), ""),
    83  		incoming:       make(chan *stack.PacketBuffer),
    84  		interfaceAddrs: interfaceAddrs,
    85  	}
    86  
    87  	ss.notifyHandle = ss.ep.AddNotify(ss)
    88  
    89  	if err := ss.stack.CreateNIC(1, ss.ep); err != nil {
    90  		return nil, fmt.Errorf("could not create NIC: %v", err)
    91  	}
    92  
    93  	// Add default routes.
    94  	var routes []tcpip.Route
    95  	if networkutil.HasIPv4(interfaceAddrs) {
    96  		routes = append(routes, tcpip.Route{
    97  			NIC:         1,
    98  			Destination: header.IPv4EmptySubnet,
    99  		})
   100  	}
   101  	if networkutil.HasIPv6(interfaceAddrs) {
   102  		routes = append(routes, tcpip.Route{
   103  			NIC:         1,
   104  			Destination: header.IPv6EmptySubnet,
   105  		})
   106  	}
   107  	ss.stack.SetRouteTable(routes)
   108  
   109  	// Assign local addresses to the nic.
   110  	for _, addr := range interfaceAddrs {
   111  		var protoNumber tcpip.NetworkProtocolNumber
   112  		if addr.Is4() {
   113  			protoNumber = ipv4.ProtocolNumber
   114  		} else if addr.Is6() {
   115  			protoNumber = ipv6.ProtocolNumber
   116  		}
   117  
   118  		protoAddr := tcpip.ProtocolAddress{
   119  			Protocol:          protoNumber,
   120  			AddressWithPrefix: tcpip.AddrFromSlice(addr.AsSlice()).WithPrefix(),
   121  		}
   122  
   123  		logger.Debug("Adding local address", slog.String("address", addr.String()))
   124  
   125  		if err := ss.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}); err != nil {
   126  			return nil, fmt.Errorf("could not add address: %v", err)
   127  		}
   128  	}
   129  
   130  	return ss, nil
   131  }
   132  
   133  func (ss *sourceSink) Close() error {
   134  	ss.ep.RemoveNotify(ss.notifyHandle)
   135  	ss.ep.Close()
   136  
   137  	ss.stack.RemoveNIC(1)
   138  
   139  	// Drain the incoming channel before closing.
   140  	ss.drain()
   141  
   142  	close(ss.incoming)
   143  
   144  	return nil
   145  }
   146  
   147  func (ss *sourceSink) Read(bufs [][]byte, sizes []int, destinations []types.NoisePublicKey, offset int) (int, error) {
   148  	packetFn := func(idx int, pkt *stack.PacketBuffer) error {
   149  		defer pkt.DecRef()
   150  
   151  		if ss.debugLogging {
   152  			ss.logger.Debug("Processing netstack packet",
   153  				slog.Uint64("packetHash", hashPacketMetadata(pkt)))
   154  		}
   155  
   156  		// Extract the destination IP address from the packet
   157  		var peerAddr netip.Addr
   158  		switch pkt.NetworkProtocolNumber {
   159  		case header.IPv4ProtocolNumber:
   160  			hdr := header.IPv4(pkt.NetworkHeader().Slice())
   161  			if !hdr.IsValid(pkt.Size()) {
   162  				return fmt.Errorf("invalid IPv4 header")
   163  			}
   164  
   165  			peerAddr = netip.AddrFrom4(hdr.DestinationAddress().As4())
   166  		case header.IPv6ProtocolNumber:
   167  			hdr := header.IPv6(pkt.NetworkHeader().Slice())
   168  			if !hdr.IsValid(pkt.Size()) {
   169  				return fmt.Errorf("invalid IPv6 header")
   170  			}
   171  
   172  			peerAddr = netip.AddrFrom16(hdr.DestinationAddress().As16())
   173  		default:
   174  			return fmt.Errorf("unknown network protocol: %w", syscall.EAFNOSUPPORT)
   175  		}
   176  
   177  		logger := ss.logger.With(slog.String("address", peerAddr.String()))
   178  
   179  		dstPeer, ok := ss.rt.destination(peerAddr)
   180  		if !ok {
   181  			return fmt.Errorf("unknown destination address")
   182  		}
   183  		destinations[idx] = dstPeer.PublicKey()
   184  
   185  		if ss.debugLogging {
   186  			logger.Debug("Sending packet to peer",
   187  				slog.String("destination", destinations[idx].DisplayString()))
   188  		}
   189  
   190  		view := pkt.ToView()
   191  		n, err := view.Read(bufs[idx][offset:])
   192  		view.Release()
   193  		if err != nil {
   194  			return fmt.Errorf("could not read packet: %w", err)
   195  		}
   196  
   197  		sizes[idx] = n
   198  
   199  		return nil
   200  	}
   201  
   202  	// Always block until we have at least one packet.
   203  	var count int
   204  	pkt, ok := <-ss.incoming
   205  	if !ok {
   206  		return 0, net.ErrClosed
   207  	}
   208  
   209  	if err := packetFn(count, pkt); err != nil {
   210  		ss.logger.Warn("Failed to process packet", slog.Any("error", err))
   211  		return count, err
   212  	}
   213  
   214  	count++
   215  
   216  	for count < len(bufs) {
   217  		select {
   218  		case pkt, ok := <-ss.incoming:
   219  			if !ok {
   220  				return count, net.ErrClosed
   221  			}
   222  
   223  			if err := packetFn(count, pkt); err != nil {
   224  				ss.logger.Warn("Failed to process packet", slog.Any("error", err))
   225  				return count, err
   226  			}
   227  
   228  			count++
   229  		default:
   230  			return count, nil
   231  		}
   232  	}
   233  
   234  	return count, nil
   235  }
   236  
   237  func (ss *sourceSink) Write(bufs [][]byte, sources []types.NoisePublicKey, offset int) (int, error) {
   238  	for i, buf := range bufs {
   239  		if len(buf) <= offset {
   240  			continue
   241  		}
   242  
   243  		if ss.debugLogging {
   244  			ss.logger.Debug("Received packet from peer", slog.String("source", sources[i].DisplayString()))
   245  		}
   246  
   247  		// Validate the source address (to prevent spoofing).
   248  		protocolNumber, err := ss.validateSourceAddress(buf[offset:], sources[i])
   249  		if err != nil {
   250  			return i, err
   251  		}
   252  
   253  		pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(buf[offset:])})
   254  
   255  		if ss.debugLogging {
   256  			ss.logger.Debug("Injecting inbound packet into netstack")
   257  		}
   258  
   259  		ss.ep.InjectInbound(protocolNumber, pkt)
   260  	}
   261  
   262  	return len(bufs), nil
   263  }
   264  
   265  func (ss *sourceSink) MTU() int {
   266  	return int(ss.ep.MTU())
   267  }
   268  
   269  func (ss *sourceSink) BatchSize() int {
   270  	return conn.IdealBatchSize
   271  }
   272  
   273  func (ss *sourceSink) WriteNotify() {
   274  	pkt := ss.ep.Read()
   275  	if pkt == nil {
   276  		return
   277  	}
   278  
   279  	if ss.debugLogging {
   280  		ss.logger.Debug("Received outbound packet from netstack",
   281  			slog.Uint64("packetHash", hashPacketMetadata(pkt)))
   282  	}
   283  
   284  	ss.incoming <- pkt
   285  }
   286  
   287  func (ss *sourceSink) drain() {
   288  	for {
   289  		select {
   290  		case pkt, ok := <-ss.incoming:
   291  			if !ok {
   292  				return
   293  			}
   294  
   295  			pkt.DecRef()
   296  		default:
   297  			return
   298  		}
   299  	}
   300  }
   301  
   302  func (ss *sourceSink) validateSourceAddress(buf []byte, source types.NoisePublicKey) (tcpip.NetworkProtocolNumber, error) {
   303  	var protocolNumber tcpip.NetworkProtocolNumber
   304  	switch header.IPVersion(buf) {
   305  	case header.IPv4Version:
   306  		protocolNumber = header.IPv4ProtocolNumber
   307  	case header.IPv6Version:
   308  		protocolNumber = header.IPv6ProtocolNumber
   309  	default:
   310  		return 0, fmt.Errorf("unknown IP version: %w", syscall.EAFNOSUPPORT)
   311  	}
   312  
   313  	var addr netip.Addr
   314  	switch protocolNumber {
   315  	case header.IPv4ProtocolNumber:
   316  		hdr := header.IPv4(buf)
   317  		if !hdr.IsValid(len(buf)) {
   318  			return protocolNumber, fmt.Errorf("invalid IPv4 header")
   319  		}
   320  
   321  		addr = netip.AddrFrom4(hdr.SourceAddress().As4())
   322  	case header.IPv6ProtocolNumber:
   323  		hdr := header.IPv6(buf)
   324  		if !hdr.IsValid(len(buf)) {
   325  			return protocolNumber, fmt.Errorf("invalid IPv6 header")
   326  		}
   327  
   328  		addr = netip.AddrFrom16(hdr.SourceAddress().As16())
   329  	default:
   330  		return protocolNumber, fmt.Errorf("unknown network protocol: %w", syscall.EAFNOSUPPORT)
   331  	}
   332  
   333  	expectedSrcPeer, ok := ss.rt.destination(addr)
   334  	if !ok {
   335  		return protocolNumber, fmt.Errorf("unknown source address")
   336  	}
   337  
   338  	if !expectedSrcPeer.PublicKey().Equals(source) {
   339  		return protocolNumber, fmt.Errorf("invalid source address for peer")
   340  	}
   341  
   342  	return protocolNumber, nil
   343  }
   344  
   345  func hashPacketMetadata(pkt *stack.PacketBuffer) uint64 {
   346  	h := fnv.New64a()
   347  	_, _ = h.Write(pkt.NetworkHeader().Slice())
   348  	_, _ = h.Write(pkt.TransportHeader().Slice())
   349  	return h.Sum64()
   350  }