github.com/telepresenceio/telepresence/v2@v2.20.0-pro.6.0.20240517030216-236ea954e789/pkg/routing/routing_windows.go (about)

     1  package routing
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"context"
     7  	"fmt"
     8  	"net"
     9  	"regexp"
    10  	"strconv"
    11  	"strings"
    12  	"time"
    13  
    14  	"golang.org/x/sys/windows"
    15  	"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
    16  
    17  	"github.com/telepresenceio/telepresence/v2/pkg/iputil"
    18  	"github.com/telepresenceio/telepresence/v2/pkg/proc"
    19  	"github.com/telepresenceio/telepresence/v2/pkg/subnet"
    20  )
    21  
// table is the value handed out by openTable. It has no fields; its Add and
// Remove methods delegate to the Route being added or removed.
type table struct{}
    23  
    24  func rowAsRoute(row *winipcfg.MibIPforwardRow2, localIP net.IP) (*Route, error) {
    25  	dst := row.DestinationPrefix.Prefix()
    26  	if !dst.IsValid() {
    27  		return nil, nil
    28  	}
    29  	gw := row.NextHop.Addr()
    30  	if !gw.IsValid() {
    31  		return nil, nil
    32  	}
    33  	ifaceIdx := int(row.InterfaceIndex)
    34  	iface, err := net.InterfaceByIndex(ifaceIdx)
    35  	if err != nil {
    36  		return nil, errInconsistentRT
    37  	}
    38  	if len(localIP) == 0 {
    39  		localIP, err = interfaceLocalIP(iface, dst.Addr().Is4())
    40  		if err != nil {
    41  			return nil, err
    42  		}
    43  	} else if ip4 := localIP.To4(); ip4 != nil {
    44  		localIP = ip4
    45  	}
    46  	if localIP == nil {
    47  		return nil, nil
    48  	}
    49  	ip := dst.Addr().AsSlice()
    50  	var mask net.IPMask
    51  	if dst.Bits() > 0 {
    52  		if dst.Addr().Is4() {
    53  			mask = net.CIDRMask(dst.Bits(), 32)
    54  		} else {
    55  			mask = net.CIDRMask(dst.Bits(), 128)
    56  		}
    57  	}
    58  	routedNet := &net.IPNet{
    59  		IP:   ip,
    60  		Mask: mask,
    61  	}
    62  	return &Route{
    63  		LocalIP:   localIP,
    64  		Gateway:   gw.AsSlice(),
    65  		RoutedNet: routedNet,
    66  		Interface: iface,
    67  		Default:   subnet.IsZeroMask(routedNet),
    68  	}, nil
    69  }
    70  
    71  func getConsistentRoutingTable(ctx context.Context) ([]*Route, error) {
    72  	table, err := winipcfg.GetIPForwardTable2(windows.AF_UNSPEC)
    73  	if err != nil {
    74  		return nil, fmt.Errorf("unable to get routing table: %w", err)
    75  	}
    76  	routes := []*Route{}
    77  	for _, row := range table {
    78  		r, err := rowAsRoute(&row, nil)
    79  		if err != nil {
    80  			return nil, err
    81  		}
    82  		if r != nil {
    83  			routes = append(routes, r)
    84  		}
    85  	}
    86  	return routes, nil
    87  }
    88  
    89  func getRouteForIP(localIP net.IP) (*Route, error) {
    90  retryInconsistent:
    91  	for i := 0; i < maxInconsistentRetries; i++ {
    92  		table, err := winipcfg.GetIPForwardTable2(windows.AF_UNSPEC)
    93  		if err != nil {
    94  			return nil, fmt.Errorf("unable to get routing table: %w", err)
    95  		}
    96  		for _, row := range table {
    97  			ifaceIdx := int(row.InterfaceIndex)
    98  			if iface, err := net.InterfaceByIndex(ifaceIdx); err == nil && iface.Flags&net.FlagUp == net.FlagUp {
    99  				if addrs, err := iface.Addrs(); err == nil {
   100  					for _, addr := range addrs {
   101  						if ip, _, err := net.ParseCIDR(addr.String()); err == nil && ip.Equal(localIP) {
   102  							r, err := rowAsRoute(&row, ip)
   103  							if err != nil {
   104  								if err == errInconsistentRT {
   105  									time.Sleep(inconsistentRetryDelay)
   106  									continue retryInconsistent
   107  								}
   108  								return nil, err
   109  							}
   110  							if r != nil {
   111  								return r, nil
   112  							}
   113  						}
   114  					}
   115  				}
   116  			}
   117  		}
   118  		break
   119  	}
   120  	return nil, fmt.Errorf("unable to get interface index for IP %s", localIP.String())
   121  }
   122  
   123  func GetRoute(ctx context.Context, routedNet *net.IPNet) (*Route, error) {
   124  	ctx, cancel := context.WithTimeout(ctx, 2*time.Second)
   125  	defer cancel()
   126  	ip := routedNet.IP
   127  	cmd := proc.CommandContext(ctx, "pathping", "-n", "-h", "1", "-p", "100", "-w", "100", "-q", "1", ip.String())
   128  	cmd.DisableLogging = true
   129  	stderr := &strings.Builder{}
   130  	cmd.Stderr = stderr
   131  	out, err := cmd.Output()
   132  	if err != nil {
   133  		return nil, fmt.Errorf("unable to run 'pathping %s': %s (%w)", ip, stderr, err)
   134  	}
   135  	var localIP net.IP
   136  	scanner := bufio.NewScanner(bytes.NewReader(out))
   137  	ipLine := regexp.MustCompile(`^\s+0\s+(\S+)\s*$`)
   138  	for scanner.Scan() {
   139  		if match := ipLine.FindStringSubmatch(scanner.Text()); match != nil {
   140  			if localIP = iputil.Parse(match[1]); localIP != nil {
   141  				break
   142  			}
   143  		}
   144  	}
   145  	if localIP == nil {
   146  		return nil, fmt.Errorf("unable to parse local IP from %q", string(out))
   147  	}
   148  	return getRouteForIP(localIP)
   149  }
   150  
   151  func maskToIP(mask net.IPMask) (ip net.IP) {
   152  	ip = make(net.IP, len(mask))
   153  	copy(ip[:], mask)
   154  	return ip
   155  }
   156  
   157  func (r *Route) addStatic(ctx context.Context) error {
   158  	cmd := proc.CommandContext(ctx,
   159  		"route",
   160  		"ADD",
   161  		r.RoutedNet.IP.String(),
   162  		"MASK",
   163  		maskToIP(r.RoutedNet.Mask).String(),
   164  		r.Gateway.String(),
   165  		"IF",
   166  		strconv.Itoa(r.Interface.Index),
   167  	)
   168  	cmd.DisableLogging = true
   169  	out, err := cmd.Output()
   170  	if err != nil {
   171  		return fmt.Errorf("failed to create route %s: %w", r, err)
   172  	}
   173  	if !strings.Contains(string(out), "OK!") {
   174  		return fmt.Errorf("failed to create route %s: %s", r, strings.TrimSpace(string(out)))
   175  	}
   176  	return nil
   177  }
   178  
   179  func (r *Route) removeStatic(ctx context.Context) error {
   180  	cmd := proc.CommandContext(ctx,
   181  		"route",
   182  		"DELETE",
   183  		r.RoutedNet.IP.String(),
   184  	)
   185  	cmd.DisableLogging = true
   186  	err := cmd.Run()
   187  	if err != nil {
   188  		return fmt.Errorf("failed to delete route %s: %w", r, err)
   189  	}
   190  	return nil
   191  }
   192  
   193  func openTable(ctx context.Context) (Table, error) {
   194  	return &table{}, nil
   195  }
   196  
   197  func (t *table) Add(ctx context.Context, r *Route) error {
   198  	return r.AddStatic(ctx)
   199  }
   200  
   201  func (t *table) Remove(ctx context.Context, r *Route) error {
   202  	return r.RemoveStatic(ctx)
   203  }
   204  
   205  func (t *table) Close(ctx context.Context) error {
   206  	return nil
   207  }