github.com/telepresenceio/telepresence/v2@v2.20.0-pro.6.0.20240517030216-236ea954e789/pkg/routing/routing_linux.go (about) 1 package routing 2 3 import ( 4 "context" 5 "fmt" 6 "net" 7 "regexp" 8 "sort" 9 "syscall" //nolint:depguard // sys/unix does not have NetlinkRIB 10 "unsafe" 11 12 "github.com/vishvananda/netlink" 13 14 "github.com/datawire/dlib/dexec" 15 "github.com/datawire/dlib/dlog" 16 "github.com/telepresenceio/telepresence/v2/pkg/iputil" 17 ) 18 19 const findInterfaceRegex = `( via (?P<gw>[0-9a-f.:]+))?.* dev (?P<dev>[a-z0-9-]+).* src (?P<src>[0-9a-f.:]+)` 20 21 var ( 22 findInterfaceRe = regexp.MustCompile(findInterfaceRegex) //nolint:gochecknoglobals // constant 23 gwidx = findInterfaceRe.SubexpIndex("gw") //nolint:gochecknoglobals // constant 24 devIdx = findInterfaceRe.SubexpIndex("dev") //nolint:gochecknoglobals // constant 25 srcIdx = findInterfaceRe.SubexpIndex("src") //nolint:gochecknoglobals // constant 26 ) 27 28 type table struct { 29 index int 30 rule *netlink.Rule 31 } 32 33 type rtmsg struct { 34 // Check out https://man7.org/linux/man-pages/man7/rtnetlink.7.html for the definition of rtmsg 35 Family byte // Address family of route 36 DstLen byte // Length of destination 37 SrcLen byte // Length of source 38 TOS byte // TOS filter 39 Table byte // Routing table ID 40 Protocol byte // Routing protocol 41 Scope byte 42 Type byte 43 44 Flags uint32 45 } 46 47 func getConsistentRoutingTable(_ context.Context) ([]*Route, error) { 48 // Most of this logic was adapted from https://github.com/google/gopacket/blob/master/routing/routing.go 49 tab, err := syscall.NetlinkRIB(syscall.RTM_GETROUTE, syscall.AF_UNSPEC) 50 if err != nil { 51 return nil, fmt.Errorf("unable to call netlink for route table: %w", err) 52 } 53 msgs, err := syscall.ParseNetlinkMessage(tab) 54 if err != nil { 55 return nil, fmt.Errorf("unable to parse netlink messages: %w", err) 56 } 57 var routes []*Route 58 msgLoop: 59 for _, msg := range msgs { 60 switch msg.Header.Type { 61 case syscall.NLMSG_DONE: 62 break msgLoop 63 case syscall.RTM_NEWROUTE: 64 // Based on the gopacket code, we mainly need this rtmsg to grab the size of the mask for the destination network. 65 r, err := rowAsRoute((*rtmsg)(unsafe.Pointer(&msg.Data[0])), &msg) 66 if err != nil { 67 return nil, err 68 } 69 if r != nil { 70 routes = append(routes, r) 71 } 72 } 73 } 74 return routes, nil 75 } 76 77 func rowAsRoute(rt *rtmsg, msg *syscall.NetlinkMessage) (*Route, error) { 78 ipv4 := false 79 switch rt.Family { 80 case syscall.AF_INET: 81 ipv4 = true 82 case syscall.AF_INET6: 83 default: 84 return nil, nil 85 } 86 attrs, err := syscall.ParseNetlinkRouteAttr(msg) 87 if err != nil { 88 return nil, fmt.Errorf("failed to parse netlink route attributes: %w", err) 89 } 90 91 var gw net.IP 92 var dstNet *net.IPNet 93 var ifaceIdx int 94 for _, attr := range attrs { 95 switch attr.Attr.Type { 96 case syscall.RTA_DST: 97 dstNet = &net.IPNet{ 98 IP: attr.Value, 99 Mask: net.CIDRMask(int(rt.DstLen), len(attr.Value)*8), 100 } 101 case syscall.RTA_GATEWAY: 102 gw = attr.Value 103 case syscall.RTA_OIF: 104 ifaceIdx = int(*(*uint32)(unsafe.Pointer(&attr.Value[0]))) 105 } 106 } 107 if ifaceIdx < 1 { 108 return nil, nil 109 } 110 111 dfltGw := false 112 // Default route -- just make the dstNet 0.0.0.0 113 if gw != nil && dstNet == nil { 114 dfltGw = true 115 if ipv4 { 116 dstNet = &net.IPNet{ 117 IP: net.IP{0, 0, 0, 0}, 118 Mask: net.CIDRMask(0, 32), 119 } 120 } else { 121 dstNet = &net.IPNet{ 122 IP: net.ParseIP("::"), 123 Mask: net.CIDRMask(0, 128), 124 } 125 } 126 } 127 if dstNet == nil { 128 return nil, nil 129 } 130 131 iface, err := net.InterfaceByIndex(ifaceIdx) 132 if err != nil { 133 // This is not an atomic operation. An intercept may vanish while we're creating this table. When that 134 // happens, the best cause of action is to redo the whole process. 135 return nil, errInconsistentRT 136 } 137 if iface.Flags&net.FlagUp == 0 { 138 return nil, nil 139 } 140 srcIP, err := interfaceLocalIP(iface, ipv4) 141 if err != nil || srcIP == nil { 142 return nil, err 143 } 144 return &Route{ 145 LocalIP: srcIP, 146 RoutedNet: dstNet, 147 Interface: iface, 148 // gw might be nil here, indicating a local route, i.e. directly connected without the packets having to go through a gateway. 149 Gateway: gw, 150 Default: dfltGw, 151 }, nil 152 } 153 154 func getOsRoute(ctx context.Context, routedNet *net.IPNet) (*Route, error) { 155 ip := routedNet.IP 156 cmd := dexec.CommandContext(ctx, "ip", "route", "get", ip.String()) 157 cmd.DisableLogging = true 158 out, err := cmd.Output() 159 if err != nil { 160 return nil, fmt.Errorf("failed to get route for %s: %w", ip, err) 161 } 162 msg := string(out) 163 match := findInterfaceRe.FindStringSubmatch(msg) 164 if match == nil { 165 return nil, fmt.Errorf("output of ip route did not match %s (output: %s)", findInterfaceRegex, msg) 166 } 167 var gatewayIP net.IP 168 gw := match[gwidx] 169 if gw != "" { 170 gatewayIP = iputil.Parse(gw) 171 if gatewayIP == nil { 172 return nil, fmt.Errorf("unable to parse gateway IP %s", gw) 173 } 174 } 175 iface, err := net.InterfaceByName(match[devIdx]) 176 if err != nil { 177 return nil, fmt.Errorf("unable to get interface %s: %w", match[devIdx], err) 178 } 179 localIP := iputil.Parse(match[srcIdx]) 180 if localIP == nil { 181 return nil, fmt.Errorf("unable to parse local IP %s", match[srcIdx]) 182 } 183 return &Route{ 184 Gateway: gatewayIP, 185 Interface: iface, 186 RoutedNet: routedNet, 187 LocalIP: localIP, 188 }, nil 189 } 190 191 func openTable(ctx context.Context) (Table, error) { 192 rules, err := netlink.RuleList(netlink.FAMILY_ALL) 193 if err != nil { 194 return nil, fmt.Errorf("netlink.RuleList: %w", err) 195 } 196 // Sort the rules by index ascending to make sure we find an open one 197 sort.Slice(rules, func(i, j int) bool { 198 return rules[i].Table < rules[j].Table 199 }) 200 index := 775 201 priority := 32766 // default initial priority 202 for _, rule := range rules { 203 dlog.Tracef(ctx, "Found routing rule %+v", rule) 204 if rule.Table == 0 || rule.Table == 255 { 205 // System rules, ignore 206 continue 207 } 208 if rule.Priority <= priority { 209 priority = rule.Priority - 1 210 } 211 if rule.Table == index { 212 // There's already a table with the default index, get a new one 213 index++ 214 } 215 } 216 dlog.Infof(ctx, "Creating routing table with index %d and priority %d", index, priority) 217 rule := netlink.NewRule() 218 rule.Table = index 219 rule.Priority = priority 220 rule.Family = netlink.FAMILY_V4 221 if err := netlink.RuleAdd(rule); err != nil { 222 return nil, fmt.Errorf("netlink.RuleAdd: %w", err) 223 } 224 return &table{ 225 index: index, 226 rule: rule, 227 }, nil 228 } 229 230 func (t *table) routeToNetlink(route *Route) *netlink.Route { 231 return &netlink.Route{ 232 Dst: route.RoutedNet, 233 Table: t.index, 234 LinkIndex: route.Interface.Index, 235 Gw: route.Gateway, 236 Src: route.LocalIP, 237 } 238 } 239 240 func (t *table) Close(ctx context.Context) error { 241 return netlink.RuleDel(t.rule) 242 } 243 244 func (t *table) Add(ctx context.Context, r *Route) error { 245 route := t.routeToNetlink(r) 246 if err := netlink.RouteAdd(route); err != nil { 247 return fmt.Errorf("netlink.RouteAdd: %w", err) 248 } 249 return nil 250 } 251 252 func (t *table) Remove(ctx context.Context, r *Route) error { 253 route := t.routeToNetlink(r) 254 if err := netlink.RouteDel(route); err != nil { 255 return fmt.Errorf("netlink.RouteDel: %w", err) 256 } 257 return nil 258 } 259 260 func (r *Route) addStatic(ctx context.Context) error { 261 return dexec.CommandContext(ctx, "ip", "route", "add", r.RoutedNet.String(), "via", r.Gateway.String(), "dev", r.Interface.Name).Run() 262 } 263 264 func (r *Route) removeStatic(ctx context.Context) error { 265 return dexec.CommandContext(ctx, "ip", "route", "del", r.RoutedNet.String(), "via", r.Gateway.String(), "dev", r.Interface.Name).Run() 266 } 267 268 func osCompareRoutes(ctx context.Context, osRoute, tableRoute *Route) (bool, error) { 269 // On Linux, when we ask about an IP address assigned to the machine, the OS will give us a loopback route 270 if osRoute.LocalIP.Equal(osRoute.RoutedNet.IP) && osRoute.Interface.Flags&net.FlagLoopback != 0 { 271 addrs, err := tableRoute.Interface.Addrs() 272 if err != nil { 273 return false, err 274 } 275 for _, addr := range addrs { 276 dlog.Tracef(ctx, "Checking address %s against %s", addr.String(), osRoute.RoutedNet.IP.String()) 277 if addr.(*net.IPNet).IP.Equal(osRoute.LocalIP) { 278 return true, nil 279 } 280 } 281 } 282 return false, nil 283 }