github.com/telepresenceio/telepresence/v2@v2.20.0-pro.6.0.20240517030216-236ea954e789/pkg/routing/routing_linux.go (about)

     1  package routing
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net"
     7  	"regexp"
     8  	"sort"
     9  	"syscall" //nolint:depguard // sys/unix does not have NetlinkRIB
    10  	"unsafe"
    11  
    12  	"github.com/vishvananda/netlink"
    13  
    14  	"github.com/datawire/dlib/dexec"
    15  	"github.com/datawire/dlib/dlog"
    16  	"github.com/telepresenceio/telepresence/v2/pkg/iputil"
    17  )
    18  
    19  const findInterfaceRegex = `( via (?P<gw>[0-9a-f.:]+))?.* dev (?P<dev>[a-z0-9-]+).* src (?P<src>[0-9a-f.:]+)`
    20  
    21  var (
    22  	findInterfaceRe = regexp.MustCompile(findInterfaceRegex) //nolint:gochecknoglobals // constant
    23  	gwidx           = findInterfaceRe.SubexpIndex("gw")      //nolint:gochecknoglobals // constant
    24  	devIdx          = findInterfaceRe.SubexpIndex("dev")     //nolint:gochecknoglobals // constant
    25  	srcIdx          = findInterfaceRe.SubexpIndex("src")     //nolint:gochecknoglobals // constant
    26  )
    27  
    28  type table struct {
    29  	index int
    30  	rule  *netlink.Rule
    31  }
    32  
    33  type rtmsg struct {
    34  	// Check out https://man7.org/linux/man-pages/man7/rtnetlink.7.html for the definition of rtmsg
    35  	Family   byte // Address family of route
    36  	DstLen   byte // Length of destination
    37  	SrcLen   byte // Length of source
    38  	TOS      byte // TOS filter
    39  	Table    byte // Routing table ID
    40  	Protocol byte // Routing protocol
    41  	Scope    byte
    42  	Type     byte
    43  
    44  	Flags uint32
    45  }
    46  
    47  func getConsistentRoutingTable(_ context.Context) ([]*Route, error) {
    48  	// Most of this logic was adapted from https://github.com/google/gopacket/blob/master/routing/routing.go
    49  	tab, err := syscall.NetlinkRIB(syscall.RTM_GETROUTE, syscall.AF_UNSPEC)
    50  	if err != nil {
    51  		return nil, fmt.Errorf("unable to call netlink for route table: %w", err)
    52  	}
    53  	msgs, err := syscall.ParseNetlinkMessage(tab)
    54  	if err != nil {
    55  		return nil, fmt.Errorf("unable to parse netlink messages: %w", err)
    56  	}
    57  	var routes []*Route
    58  msgLoop:
    59  	for _, msg := range msgs {
    60  		switch msg.Header.Type {
    61  		case syscall.NLMSG_DONE:
    62  			break msgLoop
    63  		case syscall.RTM_NEWROUTE:
    64  			// Based on the gopacket code, we mainly need this rtmsg to grab the size of the mask for the destination network.
    65  			r, err := rowAsRoute((*rtmsg)(unsafe.Pointer(&msg.Data[0])), &msg)
    66  			if err != nil {
    67  				return nil, err
    68  			}
    69  			if r != nil {
    70  				routes = append(routes, r)
    71  			}
    72  		}
    73  	}
    74  	return routes, nil
    75  }
    76  
    77  func rowAsRoute(rt *rtmsg, msg *syscall.NetlinkMessage) (*Route, error) {
    78  	ipv4 := false
    79  	switch rt.Family {
    80  	case syscall.AF_INET:
    81  		ipv4 = true
    82  	case syscall.AF_INET6:
    83  	default:
    84  		return nil, nil
    85  	}
    86  	attrs, err := syscall.ParseNetlinkRouteAttr(msg)
    87  	if err != nil {
    88  		return nil, fmt.Errorf("failed to parse netlink route attributes: %w", err)
    89  	}
    90  
    91  	var gw net.IP
    92  	var dstNet *net.IPNet
    93  	var ifaceIdx int
    94  	for _, attr := range attrs {
    95  		switch attr.Attr.Type {
    96  		case syscall.RTA_DST:
    97  			dstNet = &net.IPNet{
    98  				IP:   attr.Value,
    99  				Mask: net.CIDRMask(int(rt.DstLen), len(attr.Value)*8),
   100  			}
   101  		case syscall.RTA_GATEWAY:
   102  			gw = attr.Value
   103  		case syscall.RTA_OIF:
   104  			ifaceIdx = int(*(*uint32)(unsafe.Pointer(&attr.Value[0])))
   105  		}
   106  	}
   107  	if ifaceIdx < 1 {
   108  		return nil, nil
   109  	}
   110  
   111  	dfltGw := false
   112  	// Default route -- just make the dstNet 0.0.0.0
   113  	if gw != nil && dstNet == nil {
   114  		dfltGw = true
   115  		if ipv4 {
   116  			dstNet = &net.IPNet{
   117  				IP:   net.IP{0, 0, 0, 0},
   118  				Mask: net.CIDRMask(0, 32),
   119  			}
   120  		} else {
   121  			dstNet = &net.IPNet{
   122  				IP:   net.ParseIP("::"),
   123  				Mask: net.CIDRMask(0, 128),
   124  			}
   125  		}
   126  	}
   127  	if dstNet == nil {
   128  		return nil, nil
   129  	}
   130  
   131  	iface, err := net.InterfaceByIndex(ifaceIdx)
   132  	if err != nil {
   133  		// This is not an atomic operation. An intercept may vanish while we're creating this table. When that
   134  		// happens, the best cause of action is to redo the whole process.
   135  		return nil, errInconsistentRT
   136  	}
   137  	if iface.Flags&net.FlagUp == 0 {
   138  		return nil, nil
   139  	}
   140  	srcIP, err := interfaceLocalIP(iface, ipv4)
   141  	if err != nil || srcIP == nil {
   142  		return nil, err
   143  	}
   144  	return &Route{
   145  		LocalIP:   srcIP,
   146  		RoutedNet: dstNet,
   147  		Interface: iface,
   148  		// gw might be nil here, indicating a local route, i.e. directly connected without the packets having to go through a gateway.
   149  		Gateway: gw,
   150  		Default: dfltGw,
   151  	}, nil
   152  }
   153  
   154  func getOsRoute(ctx context.Context, routedNet *net.IPNet) (*Route, error) {
   155  	ip := routedNet.IP
   156  	cmd := dexec.CommandContext(ctx, "ip", "route", "get", ip.String())
   157  	cmd.DisableLogging = true
   158  	out, err := cmd.Output()
   159  	if err != nil {
   160  		return nil, fmt.Errorf("failed to get route for %s: %w", ip, err)
   161  	}
   162  	msg := string(out)
   163  	match := findInterfaceRe.FindStringSubmatch(msg)
   164  	if match == nil {
   165  		return nil, fmt.Errorf("output of ip route did not match %s (output: %s)", findInterfaceRegex, msg)
   166  	}
   167  	var gatewayIP net.IP
   168  	gw := match[gwidx]
   169  	if gw != "" {
   170  		gatewayIP = iputil.Parse(gw)
   171  		if gatewayIP == nil {
   172  			return nil, fmt.Errorf("unable to parse gateway IP %s", gw)
   173  		}
   174  	}
   175  	iface, err := net.InterfaceByName(match[devIdx])
   176  	if err != nil {
   177  		return nil, fmt.Errorf("unable to get interface %s: %w", match[devIdx], err)
   178  	}
   179  	localIP := iputil.Parse(match[srcIdx])
   180  	if localIP == nil {
   181  		return nil, fmt.Errorf("unable to parse local IP %s", match[srcIdx])
   182  	}
   183  	return &Route{
   184  		Gateway:   gatewayIP,
   185  		Interface: iface,
   186  		RoutedNet: routedNet,
   187  		LocalIP:   localIP,
   188  	}, nil
   189  }
   190  
   191  func openTable(ctx context.Context) (Table, error) {
   192  	rules, err := netlink.RuleList(netlink.FAMILY_ALL)
   193  	if err != nil {
   194  		return nil, fmt.Errorf("netlink.RuleList: %w", err)
   195  	}
   196  	// Sort the rules by index ascending to make sure we find an open one
   197  	sort.Slice(rules, func(i, j int) bool {
   198  		return rules[i].Table < rules[j].Table
   199  	})
   200  	index := 775
   201  	priority := 32766 // default initial priority
   202  	for _, rule := range rules {
   203  		dlog.Tracef(ctx, "Found routing rule %+v", rule)
   204  		if rule.Table == 0 || rule.Table == 255 {
   205  			// System rules, ignore
   206  			continue
   207  		}
   208  		if rule.Priority <= priority {
   209  			priority = rule.Priority - 1
   210  		}
   211  		if rule.Table == index {
   212  			// There's already a table with the default index, get a new one
   213  			index++
   214  		}
   215  	}
   216  	dlog.Infof(ctx, "Creating routing table with index %d and priority %d", index, priority)
   217  	rule := netlink.NewRule()
   218  	rule.Table = index
   219  	rule.Priority = priority
   220  	rule.Family = netlink.FAMILY_V4
   221  	if err := netlink.RuleAdd(rule); err != nil {
   222  		return nil, fmt.Errorf("netlink.RuleAdd: %w", err)
   223  	}
   224  	return &table{
   225  		index: index,
   226  		rule:  rule,
   227  	}, nil
   228  }
   229  
   230  func (t *table) routeToNetlink(route *Route) *netlink.Route {
   231  	return &netlink.Route{
   232  		Dst:       route.RoutedNet,
   233  		Table:     t.index,
   234  		LinkIndex: route.Interface.Index,
   235  		Gw:        route.Gateway,
   236  		Src:       route.LocalIP,
   237  	}
   238  }
   239  
   240  func (t *table) Close(ctx context.Context) error {
   241  	return netlink.RuleDel(t.rule)
   242  }
   243  
   244  func (t *table) Add(ctx context.Context, r *Route) error {
   245  	route := t.routeToNetlink(r)
   246  	if err := netlink.RouteAdd(route); err != nil {
   247  		return fmt.Errorf("netlink.RouteAdd: %w", err)
   248  	}
   249  	return nil
   250  }
   251  
   252  func (t *table) Remove(ctx context.Context, r *Route) error {
   253  	route := t.routeToNetlink(r)
   254  	if err := netlink.RouteDel(route); err != nil {
   255  		return fmt.Errorf("netlink.RouteDel: %w", err)
   256  	}
   257  	return nil
   258  }
   259  
   260  func (r *Route) addStatic(ctx context.Context) error {
   261  	return dexec.CommandContext(ctx, "ip", "route", "add", r.RoutedNet.String(), "via", r.Gateway.String(), "dev", r.Interface.Name).Run()
   262  }
   263  
   264  func (r *Route) removeStatic(ctx context.Context) error {
   265  	return dexec.CommandContext(ctx, "ip", "route", "del", r.RoutedNet.String(), "via", r.Gateway.String(), "dev", r.Interface.Name).Run()
   266  }
   267  
   268  func osCompareRoutes(ctx context.Context, osRoute, tableRoute *Route) (bool, error) {
   269  	// On Linux, when we ask about an IP address assigned to the machine, the OS will give us a loopback route
   270  	if osRoute.LocalIP.Equal(osRoute.RoutedNet.IP) && osRoute.Interface.Flags&net.FlagLoopback != 0 {
   271  		addrs, err := tableRoute.Interface.Addrs()
   272  		if err != nil {
   273  			return false, err
   274  		}
   275  		for _, addr := range addrs {
   276  			dlog.Tracef(ctx, "Checking address %s against %s", addr.String(), osRoute.RoutedNet.IP.String())
   277  			if addr.(*net.IPNet).IP.Equal(osRoute.LocalIP) {
   278  				return true, nil
   279  			}
   280  		}
   281  	}
   282  	return false, nil
   283  }