github.com/noisysockets/noisysockets@v0.21.2-0.20240515114641-7f467e651c90/routing_table.go (about)

     1  // SPDX-License-Identifier: MPL-2.0
     2  /*
     3   * Copyright (C) 2024 The Noisy Sockets Authors.
     4   *
     5   * This Source Code Form is subject to the terms of the Mozilla Public
     6   * License, v. 2.0. If a copy of the MPL was not distributed with this
     7   * file, You can obtain one at http://mozilla.org/MPL/2.0/.
     8   */
     9  
    10  package noisysockets
    11  
    12  import (
    13  	"fmt"
    14  	"log/slog"
    15  	"net/netip"
    16  	"sync"
    17  
    18  	"github.com/noisysockets/noisysockets/internal/uint128"
    19  	"github.com/noisysockets/noisysockets/internal/util"
    20  	"github.com/rdleal/intervalst/interval"
    21  )
    22  
// peerWithPrefix is the value stored in the routing table's interval tree:
// the peer that owns a route together with the prefix the route was derived
// from. destination uses the prefix length to pick the most specific route
// when several stored intervals cover the same address.
type peerWithPrefix struct {
	*Peer
	prefix netip.Prefix
}
    27  
// intervalRange records one interval that a peer contributed to the routing
// table. start and end are the inclusive bounds of the covered address range
// encoded with addrToUint128 (a single address has start == end), and
// withPrefix is the value inserted into the interval tree for that range.
// update compares these records by range to decide what to insert or delete,
// and remove uses them to delete a peer's ranges from the tree.
type intervalRange struct {
	withPrefix *peerWithPrefix
	start      uint128.Uint128
	end        uint128.Uint128
}
    33  
// routingTable is a routing table that maps IP addresses to peers.
//
// Every route is stored in destinations as an interval of 128-bit keys (see
// addrToUint128). intervalsByPeer remembers which intervals each peer owns so
// they can be diffed on update and deleted on remove. mu guards both: lookups
// take the read lock, mutations take the write lock.
type routingTable struct {
	mu              sync.RWMutex
	logger          *slog.Logger
	intervalsByPeer map[*Peer][]intervalRange
	destinations    *interval.SearchTree[*peerWithPrefix, uint128.Uint128]
}
    41  
    42  func newRoutingTable(logger *slog.Logger) *routingTable {
    43  	return &routingTable{
    44  		logger:          logger,
    45  		intervalsByPeer: make(map[*Peer][]intervalRange),
    46  		destinations: interval.NewSearchTreeWithOptions[*peerWithPrefix](func(k1, k2 uint128.Uint128) int {
    47  			return k1.Cmp(k2)
    48  		}, interval.TreeWithIntervalPoint()),
    49  	}
    50  }
    51  
    52  func (rt *routingTable) destination(addr netip.Addr) (*Peer, bool) {
    53  	rt.mu.RLock()
    54  	defer rt.mu.RUnlock()
    55  
    56  	addrInt := addrToUint128(addr)
    57  	destinations, ok := rt.destinations.AllIntersections(addrInt, addrInt)
    58  	if !ok {
    59  		return nil, false
    60  	}
    61  
    62  	// Look for the longest prefix length.
    63  	var destination *peerWithPrefix
    64  	for _, p := range destinations {
    65  		if destination == nil || p.prefix.Bits() > destination.prefix.Bits() {
    66  			destination = p
    67  		}
    68  	}
    69  
    70  	return destination.Peer, true
    71  }
    72  
    73  // update upserts the routing table with the peer's information.
    74  func (rt *routingTable) update(p *Peer) error {
    75  	rt.mu.Lock()
    76  	defer rt.mu.Unlock()
    77  
    78  	// Compute the intervals for the peer.
    79  	intervals, err := rt.intervals(p)
    80  	if err != nil {
    81  		return fmt.Errorf("failed to compute intervals for peer %s: %w", p.Name(), err)
    82  	}
    83  
    84  	// Update the intervals for the peer.
    85  	for _, interval := range intervals {
    86  		found := false
    87  		for _, existingInterval := range rt.intervalsByPeer[p] {
    88  			if existingInterval.start.Cmp(interval.start) == 0 && existingInterval.end.Cmp(interval.end) == 0 {
    89  				found = true
    90  				break
    91  			}
    92  		}
    93  
    94  		if !found {
    95  			rt.logger.Debug("Adding interval to routing table",
    96  				slog.String("peer", p.Name()),
    97  				slog.String("prefix", interval.withPrefix.prefix.String()))
    98  
    99  			if err := rt.destinations.Insert(interval.start, interval.end, interval.withPrefix); err != nil {
   100  				return fmt.Errorf("failed to add interval to routing table: %w", err)
   101  			}
   102  		}
   103  	}
   104  
   105  	// Remove existing intervals that are not needed anymore.
   106  	for _, interval := range rt.intervalsByPeer[p] {
   107  		found := false
   108  		for _, newInterval := range intervals {
   109  			if interval.start.Cmp(newInterval.start) == 0 && interval.end.Cmp(newInterval.end) == 0 {
   110  				found = true
   111  				break
   112  			}
   113  		}
   114  
   115  		if !found {
   116  			rt.logger.Debug("Removing prefix from routing table",
   117  				slog.String("peer", p.Name()),
   118  				slog.String("prefix", interval.withPrefix.prefix.String()))
   119  
   120  			if err := rt.destinations.Delete(interval.start, interval.end); err != nil {
   121  				return fmt.Errorf("failed to remove interval from routing table: %w", err)
   122  			}
   123  		}
   124  	}
   125  
   126  	rt.intervalsByPeer[p] = intervals
   127  
   128  	return nil
   129  }
   130  
   131  // remove removes the peer from the routing table.
   132  func (rt *routingTable) remove(p *Peer) error {
   133  	rt.mu.Lock()
   134  	defer rt.mu.Unlock()
   135  
   136  	for _, interval := range rt.intervalsByPeer[p] {
   137  		rt.logger.Debug("Removing prefix from routing table",
   138  			slog.String("peer", p.Name()),
   139  			slog.String("prefix", interval.withPrefix.prefix.String()))
   140  
   141  		if err := rt.destinations.Delete(interval.start, interval.end); err != nil {
   142  			return fmt.Errorf("failed to remove interval from routing table: %w", err)
   143  		}
   144  	}
   145  
   146  	delete(rt.intervalsByPeer, p)
   147  
   148  	return nil
   149  }
   150  
   151  func (rt *routingTable) intervals(p *Peer) ([]intervalRange, error) {
   152  	var intervals []intervalRange
   153  
   154  	// Add all the peer's addresses.
   155  	for _, addr := range p.Addresses() {
   156  		addrInt := addrToUint128(addr)
   157  
   158  		var prefix netip.Prefix
   159  		if addr.Is4() {
   160  			prefix = netip.MustParsePrefix(addr.String() + "/32")
   161  		} else {
   162  			prefix = netip.MustParsePrefix(addr.String() + "/128")
   163  		}
   164  
   165  		rt.logger.Debug("Adding address to routing table",
   166  			slog.String("peer", p.Name()),
   167  			slog.String("address", addr.String()))
   168  
   169  		intervals = append(intervals, intervalRange{
   170  			withPrefix: &peerWithPrefix{
   171  				Peer:   p,
   172  				prefix: prefix,
   173  			},
   174  			start: addrInt,
   175  			end:   addrInt,
   176  		})
   177  	}
   178  
   179  	// Add any registered routes.
   180  	for _, prefix := range p.DestinationForPrefixes() {
   181  		startAddr, endAddr, err := util.PrefixRange(prefix)
   182  		if err != nil {
   183  			return nil, fmt.Errorf("failed to calculate range for prefix %s: %w", prefix, err)
   184  		}
   185  
   186  		rt.logger.Debug("Adding prefix to routing table",
   187  			slog.String("peer", p.Name()),
   188  			slog.String("prefix", prefix.String()))
   189  
   190  		startAddrInt := addrToUint128(startAddr)
   191  		endAddrInt := addrToUint128(endAddr)
   192  
   193  		intervals = append(intervals, intervalRange{
   194  			withPrefix: &peerWithPrefix{
   195  				Peer:   p,
   196  				prefix: prefix,
   197  			},
   198  			start: startAddrInt,
   199  			end:   endAddrInt,
   200  		})
   201  	}
   202  
   203  	return intervals, nil
   204  }
   205  
   206  func addrToUint128(addr netip.Addr) uint128.Uint128 {
   207  	if addr.Is4() {
   208  		as16 := addr.As16()
   209  		return uint128.FromBytes(as16[:]).ReverseBytes()
   210  	}
   211  
   212  	return uint128.FromBytes(addr.AsSlice()).ReverseBytes()
   213  }