github.com/telepresenceio/telepresence/v2@v2.20.0-pro.6.0.20240517030216-236ea954e789/pkg/routing/routing.go (about)

     1  package routing
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"net"
     8  	"time"
     9  
    10  	"go.opentelemetry.io/otel"
    11  	"go.opentelemetry.io/otel/attribute"
    12  	"go.opentelemetry.io/otel/trace"
    13  
    14  	"github.com/datawire/dlib/dlog"
    15  	"github.com/telepresenceio/telepresence/v2/pkg/tracing"
    16  )
    17  
    18  type Route struct {
    19  	LocalIP   net.IP
    20  	RoutedNet *net.IPNet
    21  	Interface *net.Interface
    22  	Gateway   net.IP
    23  	Default   bool
    24  }
    25  
    26  type Table interface {
    27  	// Add adds a route to the routing table
    28  	Add(ctx context.Context, r *Route) error
    29  	// Remove removes a route from the routing table
    30  	Remove(ctx context.Context, r *Route) error
    31  	// Close closes the routing table
    32  	Close(ctx context.Context) error
    33  }
    34  
    35  func OpenTable(ctx context.Context) (Table, error) {
    36  	return openTable(ctx)
    37  }
    38  
    39  func DefaultRoute(ctx context.Context) (*Route, error) {
    40  	rt, err := GetRoutingTable(ctx)
    41  	if err != nil {
    42  		return nil, err
    43  	}
    44  	for _, r := range rt {
    45  		if r.Default {
    46  			return r, nil
    47  		}
    48  	}
    49  	return nil, errors.New("unable to find a default route")
    50  }
    51  
    52  type rtError string
    53  
    54  func (r rtError) Error() string {
    55  	return string(r)
    56  }
    57  
    58  const (
    59  	errInconsistentRT      = rtError("routing table is inconsistent")
    60  	maxInconsistentRetries = 3
    61  	inconsistentRetryDelay = 50 * time.Millisecond
    62  )
    63  
    64  // GetRoutingTable will return a list of Route objects created from the current routing table.
    65  func GetRoutingTable(ctx context.Context) ([]*Route, error) {
    66  	// The process of creating routes is not atomic. If an intercept is deleted shortly before this function is
    67  	// called, then an interface referenced from a route might no longer exist. When this happens, there will
    68  	// be a short delay followed by a retry.
    69  	for i := 0; i < maxInconsistentRetries; i++ {
    70  		rt, err := getConsistentRoutingTable(ctx)
    71  		if err != errInconsistentRT {
    72  			return rt, err
    73  		}
    74  		time.Sleep(inconsistentRetryDelay)
    75  	}
    76  	return nil, errInconsistentRT
    77  }
    78  
    79  func (r *Route) Routes(ip net.IP) bool {
    80  	return r.RoutedNet.Contains(ip)
    81  }
    82  
    83  func (r *Route) String() string {
    84  	isDefault := " (default)"
    85  	if !r.Default {
    86  		isDefault = ""
    87  	}
    88  	return fmt.Sprintf("%s via %s dev %s, gw %s%s", r.RoutedNet, r.LocalIP, r.Interface.Name, r.Gateway, isDefault)
    89  }
    90  
    91  // AddStatic adds a specific route. This can be used to prevent certain IP addresses
    92  // from being routed to the route's interface.
    93  func (r *Route) AddStatic(ctx context.Context) (err error) {
    94  	dlog.Debugf(ctx, "Adding static route %s", r)
    95  	ctx, span := otel.GetTracerProvider().Tracer("").Start(ctx, "AddStatic", trace.WithAttributes(attribute.Stringer("tel2.route", r)))
    96  	defer tracing.EndAndRecord(span, err)
    97  	return r.addStatic(ctx)
    98  }
    99  
   100  // RemoveStatic removes a specific route added via AddStatic.
   101  func (r *Route) RemoveStatic(ctx context.Context) (err error) {
   102  	dlog.Debugf(ctx, "Dropping static route %s", r)
   103  	ctx, span := otel.GetTracerProvider().Tracer("").Start(ctx, "RemoveStaticRoute", trace.WithAttributes(attribute.Stringer("tel2.route", r)))
   104  	defer tracing.EndAndRecord(span, err)
   105  	return r.removeStatic(ctx)
   106  }
   107  
   108  func interfaceLocalIP(iface *net.Interface, ipv4 bool) (net.IP, error) {
   109  	addrs, err := iface.Addrs()
   110  	if err != nil {
   111  		return net.IP{}, fmt.Errorf("unable to get interface addresses for interface %s: %w", iface.Name, err)
   112  	}
   113  	for _, addr := range addrs {
   114  		ip, _, err := net.ParseCIDR(addr.String())
   115  		if err != nil {
   116  			return net.IP{}, fmt.Errorf("unable to parse address %s: %v", addr.String(), err)
   117  		}
   118  		if ip4 := ip.To4(); ip4 != nil {
   119  			if !ipv4 {
   120  				continue
   121  			}
   122  			return ip4, nil
   123  		} else if ipv4 {
   124  			continue
   125  		}
   126  		return ip, nil
   127  	}
   128  	return nil, nil
   129  }