github.com/noisysockets/noisysockets@v0.21.2-0.20240515114641-7f467e651c90/routing_table.go (about) 1 // SPDX-License-Identifier: MPL-2.0 2 /* 3 * Copyright (C) 2024 The Noisy Sockets Authors. 4 * 5 * This Source Code Form is subject to the terms of the Mozilla Public 6 * License, v. 2.0. If a copy of the MPL was not distributed with this 7 * file, You can obtain one at http://mozilla.org/MPL/2.0/. 8 */ 9 10 package noisysockets 11 12 import ( 13 "fmt" 14 "log/slog" 15 "net/netip" 16 "sync" 17 18 "github.com/noisysockets/noisysockets/internal/uint128" 19 "github.com/noisysockets/noisysockets/internal/util" 20 "github.com/rdleal/intervalst/interval" 21 ) 22 23 type peerWithPrefix struct { 24 *Peer 25 prefix netip.Prefix 26 } 27 28 type intervalRange struct { 29 withPrefix *peerWithPrefix 30 start uint128.Uint128 31 end uint128.Uint128 32 } 33 34 // routingTable is a routing table that maps ip addresses to peers. 35 type routingTable struct { 36 mu sync.RWMutex 37 logger *slog.Logger 38 intervalsByPeer map[*Peer][]intervalRange 39 destinations *interval.SearchTree[*peerWithPrefix, uint128.Uint128] 40 } 41 42 func newRoutingTable(logger *slog.Logger) *routingTable { 43 return &routingTable{ 44 logger: logger, 45 intervalsByPeer: make(map[*Peer][]intervalRange), 46 destinations: interval.NewSearchTreeWithOptions[*peerWithPrefix](func(k1, k2 uint128.Uint128) int { 47 return k1.Cmp(k2) 48 }, interval.TreeWithIntervalPoint()), 49 } 50 } 51 52 func (rt *routingTable) destination(addr netip.Addr) (*Peer, bool) { 53 rt.mu.RLock() 54 defer rt.mu.RUnlock() 55 56 addrInt := addrToUint128(addr) 57 destinations, ok := rt.destinations.AllIntersections(addrInt, addrInt) 58 if !ok { 59 return nil, false 60 } 61 62 // Look for the longest prefix length. 63 var destination *peerWithPrefix 64 for _, p := range destinations { 65 if destination == nil || p.prefix.Bits() > destination.prefix.Bits() { 66 destination = p 67 } 68 } 69 70 return destination.Peer, true 71 } 72 73 // update upserts the routing table with the peer's information. 74 func (rt *routingTable) update(p *Peer) error { 75 rt.mu.Lock() 76 defer rt.mu.Unlock() 77 78 // Compute the intervals for the peer. 79 intervals, err := rt.intervals(p) 80 if err != nil { 81 return fmt.Errorf("failed to compute intervals for peer %s: %w", p.Name(), err) 82 } 83 84 // Update the intervals for the peer. 85 for _, interval := range intervals { 86 found := false 87 for _, existingInterval := range rt.intervalsByPeer[p] { 88 if existingInterval.start.Cmp(interval.start) == 0 && existingInterval.end.Cmp(interval.end) == 0 { 89 found = true 90 break 91 } 92 } 93 94 if !found { 95 rt.logger.Debug("Adding interval to routing table", 96 slog.String("peer", p.Name()), 97 slog.String("prefix", interval.withPrefix.prefix.String())) 98 99 if err := rt.destinations.Insert(interval.start, interval.end, interval.withPrefix); err != nil { 100 return fmt.Errorf("failed to add interval to routing table: %w", err) 101 } 102 } 103 } 104 105 // Remove existing intervals that are not needed anymore. 106 for _, interval := range rt.intervalsByPeer[p] { 107 found := false 108 for _, newInterval := range intervals { 109 if interval.start.Cmp(newInterval.start) == 0 && interval.end.Cmp(newInterval.end) == 0 { 110 found = true 111 break 112 } 113 } 114 115 if !found { 116 rt.logger.Debug("Removing prefix from routing table", 117 slog.String("peer", p.Name()), 118 slog.String("prefix", interval.withPrefix.prefix.String())) 119 120 if err := rt.destinations.Delete(interval.start, interval.end); err != nil { 121 return fmt.Errorf("failed to remove interval from routing table: %w", err) 122 } 123 } 124 } 125 126 rt.intervalsByPeer[p] = intervals 127 128 return nil 129 } 130 131 // remove removes the peer from the routing table. 132 func (rt *routingTable) remove(p *Peer) error { 133 rt.mu.Lock() 134 defer rt.mu.Unlock() 135 136 for _, interval := range rt.intervalsByPeer[p] { 137 rt.logger.Debug("Removing prefix from routing table", 138 slog.String("peer", p.Name()), 139 slog.String("prefix", interval.withPrefix.prefix.String())) 140 141 if err := rt.destinations.Delete(interval.start, interval.end); err != nil { 142 return fmt.Errorf("failed to remove interval from routing table: %w", err) 143 } 144 } 145 146 delete(rt.intervalsByPeer, p) 147 148 return nil 149 } 150 151 func (rt *routingTable) intervals(p *Peer) ([]intervalRange, error) { 152 var intervals []intervalRange 153 154 // Add all the peer's addresses. 155 for _, addr := range p.Addresses() { 156 addrInt := addrToUint128(addr) 157 158 var prefix netip.Prefix 159 if addr.Is4() { 160 prefix = netip.MustParsePrefix(addr.String() + "/32") 161 } else { 162 prefix = netip.MustParsePrefix(addr.String() + "/128") 163 } 164 165 rt.logger.Debug("Adding address to routing table", 166 slog.String("peer", p.Name()), 167 slog.String("address", addr.String())) 168 169 intervals = append(intervals, intervalRange{ 170 withPrefix: &peerWithPrefix{ 171 Peer: p, 172 prefix: prefix, 173 }, 174 start: addrInt, 175 end: addrInt, 176 }) 177 } 178 179 // Add any registered routes. 180 for _, prefix := range p.DestinationForPrefixes() { 181 startAddr, endAddr, err := util.PrefixRange(prefix) 182 if err != nil { 183 return nil, fmt.Errorf("failed to calculate range for prefix %s: %w", prefix, err) 184 } 185 186 rt.logger.Debug("Adding prefix to routing table", 187 slog.String("peer", p.Name()), 188 slog.String("prefix", prefix.String())) 189 190 startAddrInt := addrToUint128(startAddr) 191 endAddrInt := addrToUint128(endAddr) 192 193 intervals = append(intervals, intervalRange{ 194 withPrefix: &peerWithPrefix{ 195 Peer: p, 196 prefix: prefix, 197 }, 198 start: startAddrInt, 199 end: endAddrInt, 200 }) 201 } 202 203 return intervals, nil 204 } 205 206 func addrToUint128(addr netip.Addr) uint128.Uint128 { 207 if addr.Is4() { 208 as16 := addr.As16() 209 return uint128.FromBytes(as16[:]).ReverseBytes() 210 } 211 212 return uint128.FromBytes(addr.AsSlice()).ReverseBytes() 213 }