github.com/telepresenceio/telepresence/v2@v2.20.0-pro.6.0.20240517030216-236ea954e789/pkg/vif/device_windows.go (about)

     1  package vif
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"net"
     9  	"net/netip"
    10  	"os"
    11  	"slices"
    12  	"strings"
    13  	"time"
    14  
    15  	"golang.org/x/sys/windows"
    16  	"golang.org/x/sys/windows/registry"
    17  	"golang.zx2c4.com/wireguard/tun"
    18  	"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
    19  
    20  	"github.com/datawire/dlib/derror"
    21  	"github.com/datawire/dlib/dlog"
    22  	"github.com/telepresenceio/telepresence/v2/pkg/client"
    23  	"github.com/telepresenceio/telepresence/v2/pkg/proc"
    24  	"github.com/telepresenceio/telepresence/v2/pkg/vif/buffer"
    25  )
    26  
    27  // This nativeDevice will require that wintun.dll is available to the loader.
    28  // See: https://www.wintun.net/ for more info.
    29  type nativeDevice struct {
    30  	tun.Device
    31  	strategy            client.GSCStrategy
    32  	name                string
    33  	dns                 net.IP
    34  	interfaceIndex      int32
    35  	searchListAdditions map[string]struct{}
    36  }
    37  
    38  func openTun(ctx context.Context) (td *nativeDevice, err error) {
    39  	defer func() {
    40  		if r := recover(); r != nil {
    41  			err = derror.PanicToError(r)
    42  			dlog.Errorf(ctx, "%+v", err)
    43  		}
    44  	}()
    45  	interfaceFmt := "tel%d"
    46  	ifaceNumber := 0
    47  	ifaces, err := net.Interfaces()
    48  	if err != nil {
    49  		return nil, fmt.Errorf("failed to get interfaces: %w", err)
    50  	}
    51  	for _, iface := range ifaces {
    52  		dlog.Tracef(ctx, "Found interface %s", iface.Name)
    53  		// Parse the tel%d number if it's there
    54  		var num int
    55  		if _, err := fmt.Sscanf(iface.Name, interfaceFmt, &num); err == nil {
    56  			if num >= ifaceNumber {
    57  				ifaceNumber = num + 1
    58  			}
    59  		}
    60  	}
    61  	interfaceName := fmt.Sprintf(interfaceFmt, ifaceNumber)
    62  	dlog.Infof(ctx, "Creating interface %s", interfaceName)
    63  	td = &nativeDevice{
    64  		searchListAdditions: make(map[string]struct{}),
    65  	}
    66  	if td.Device, err = tun.CreateTUN(interfaceName, 0); err != nil {
    67  		return nil, fmt.Errorf("failed to create TUN device: %w", err)
    68  	}
    69  	if td.name, err = td.Device.Name(); err != nil {
    70  		return nil, fmt.Errorf("failed to get real name of TUN device: %w", err)
    71  	}
    72  	iface, err := td.getLUID().Interface()
    73  	if err != nil {
    74  		return nil, fmt.Errorf("failed to get interface for TUN device: %w", err)
    75  	}
    76  	td.interfaceIndex = int32(iface.InterfaceIndex)
    77  	td.strategy = client.GetConfig(ctx).OSSpecific().Network.GlobalDNSSearchConfigStrategy
    78  
    79  	return td, nil
    80  }
    81  
    82  func (t *nativeDevice) Close() error {
    83  	// The tun.NativeTun device has a closing mutex which is read locked during
    84  	// a call to Read(). The read lock prevents a call to Close() to proceed
    85  	// until Read() actually receives something. To resolve that "deadlock",
    86  	// we call Close() in one goroutine to wait for the lock and write a bogus
    87  	// message in another that will be returned by Read().
    88  	closeCh := make(chan error)
    89  	go func() {
    90  		// first message is just to indicate that this goroutine has started
    91  		closeCh <- nil
    92  		closeCh <- t.Device.Close()
    93  		close(closeCh)
    94  	}()
    95  
    96  	// Not 100%, but we can be fairly sure that Close() is
    97  	// hanging on the lock, or at least will be by the time
    98  	// the Read() returns
    99  	<-closeCh
   100  
   101  	// Send something to the TUN device so that the Read
   102  	// unlocks the NativeTun.closing mutex and let the actual
   103  	// Close call continue
   104  	conn, err := net.Dial("udp", net.JoinHostPort(t.dns.String(), "53"))
   105  	if err == nil {
   106  		_, _ = conn.Write([]byte("bogus"))
   107  	}
   108  	return <-closeCh
   109  }
   110  
   111  func (t *nativeDevice) getLUID() winipcfg.LUID {
   112  	return winipcfg.LUID(t.Device.(*tun.NativeTun).LUID())
   113  }
   114  
   115  func (t *nativeDevice) index() int32 {
   116  	return t.interfaceIndex
   117  }
   118  
   119  func addrFromIP(ip net.IP) netip.Addr {
   120  	var addr netip.Addr
   121  	if ip4 := ip.To4(); ip4 != nil {
   122  		addr = netip.AddrFrom4(*(*[4]byte)(ip4))
   123  	} else if ip16 := ip.To16(); ip16 != nil {
   124  		addr = netip.AddrFrom16(*(*[16]byte)(ip16))
   125  	}
   126  	return addr
   127  }
   128  
   129  func prefixFromIPNet(subnet *net.IPNet) netip.Prefix {
   130  	if subnet == nil {
   131  		return netip.Prefix{}
   132  	}
   133  	ones, _ := subnet.Mask.Size()
   134  	return netip.PrefixFrom(addrFromIP(subnet.IP), ones)
   135  }
   136  
   137  func (t *nativeDevice) addSubnet(_ context.Context, subnet *net.IPNet) error {
   138  	return t.getLUID().AddIPAddress(prefixFromIPNet(subnet))
   139  }
   140  
   141  func (t *nativeDevice) removeSubnet(_ context.Context, subnet *net.IPNet) error {
   142  	return t.getLUID().DeleteIPAddress(prefixFromIPNet(subnet))
   143  }
   144  
   145  func (t *nativeDevice) setDNS(ctx context.Context, _ string, server net.IP, searchList []string) (err error) {
   146  	// This function must not be interrupted by a context cancellation, so we give it a timeout instead.
   147  	dlog.Debugf(ctx, "SetDNS server: %s, searchList: %v", server, searchList)
   148  	defer dlog.Debug(ctx, "SetDNS done")
   149  
   150  	parentCtx := ctx
   151  	ctx, cancel := context.WithCancel(context.WithoutCancel(ctx))
   152  	defer cancel()
   153  
   154  	go func() {
   155  		<-parentCtx.Done()
   156  		// Give this function some time to complete its task after the parentCtx is done. Configuring DSN on windows is slow
   157  		// and we don't want to interrupt it.
   158  		time.AfterFunc(10*time.Second, cancel)
   159  	}()
   160  
   161  	ipFamily := func(ip net.IP) winipcfg.AddressFamily {
   162  		f := winipcfg.AddressFamily(windows.AF_INET6)
   163  		if ip4 := ip.To4(); ip4 != nil {
   164  			f = windows.AF_INET
   165  		}
   166  		return f
   167  	}
   168  	family := ipFamily(server)
   169  	luid := t.getLUID()
   170  	if t.dns != nil {
   171  		if oldFamily := ipFamily(t.dns); oldFamily != family {
   172  			_ = luid.FlushDNS(oldFamily)
   173  		}
   174  	}
   175  	serverStr := server.String()
   176  	servers16, err := windows.UTF16PtrFromString(serverStr)
   177  	if err != nil {
   178  		return err
   179  	}
   180  	searchList16, err := windows.UTF16PtrFromString(strings.Join(searchList, ","))
   181  	if err != nil {
   182  		return err
   183  	}
   184  	guid, err := luid.GUID()
   185  	if err != nil {
   186  		return err
   187  	}
   188  	dnsInterfaceSettings := &winipcfg.DnsInterfaceSettings{
   189  		Version:    winipcfg.DnsInterfaceSettingsVersion1,
   190  		Flags:      winipcfg.DnsInterfaceSettingsFlagNameserver | winipcfg.DnsInterfaceSettingsFlagSearchList,
   191  		NameServer: servers16,
   192  		SearchList: searchList16,
   193  	}
   194  	if family == windows.AF_INET6 {
   195  		dnsInterfaceSettings.Flags |= winipcfg.DnsInterfaceSettingsFlagIPv6
   196  	}
   197  	if err = winipcfg.SetInterfaceDnsSettings(*guid, dnsInterfaceSettings); err != nil {
   198  		return err
   199  	}
   200  
   201  	// Unless we also update the global DNS search path, the one for the device doesn't work on some platforms.
   202  	// This behavior is mainly observed on Windows Server editions.
   203  
   204  	// Retrieve the current global search paths so that paths that aren't managed by us can be retained.
   205  	gss, err := getGlobalSearchList()
   206  	if err != nil {
   207  		return err
   208  	}
   209  	// Put our new search path in front of other entries.
   210  	uniq := make(map[string]int, len(searchList)+len(gss))
   211  	i := 0
   212  	for _, gs := range searchList {
   213  		gs = strings.TrimSuffix(gs, ".")
   214  		t.searchListAdditions[gs] = struct{}{}
   215  		if _, ok := uniq[gs]; !ok {
   216  			uniq[gs] = i
   217  			i++
   218  		}
   219  	}
   220  
   221  	// Include entries that aren't managed by Telepresence.
   222  	for _, gs := range gss {
   223  		if _, ok := t.searchListAdditions[gs]; !ok {
   224  			if _, ok := uniq[gs]; !ok {
   225  				uniq[gs] = i
   226  				i++
   227  			}
   228  		}
   229  	}
   230  
   231  	gss = make([]string, len(uniq))
   232  	for gs, i := range uniq {
   233  		gss[i] = gs
   234  	}
   235  	t.dns = server
   236  	if err := t.setGlobalSearchList(ctx, gss); err != nil {
   237  		return err
   238  	}
   239  
   240  	// Prune the list of additions using the current search path.
   241  	for gs := range t.searchListAdditions {
   242  		if !slices.Contains(gss, gs) {
   243  			delete(t.searchListAdditions, gs)
   244  		}
   245  	}
   246  	return nil
   247  }
   248  
   249  func psList(values []string) string {
   250  	var sb strings.Builder
   251  	sb.WriteString("@(")
   252  	for i, gs := range values {
   253  		if i > 0 {
   254  			sb.WriteByte(',')
   255  		}
   256  		sb.WriteByte('"')
   257  		sb.WriteString(gs)
   258  		sb.WriteByte('"')
   259  	}
   260  	sb.WriteByte(')')
   261  	return sb.String()
   262  }
   263  
   264  const (
   265  	tcpParamKey   = `System\CurrentControlSet\Services\Tcpip\Parameters`
   266  	searchListKey = `SearchList`
   267  )
   268  
   269  func getGlobalSearchList() ([]string, error) {
   270  	rk, err := registry.OpenKey(registry.LOCAL_MACHINE, tcpParamKey, registry.QUERY_VALUE)
   271  	if err != nil {
   272  		if os.IsNotExist(err) {
   273  			err = nil
   274  		}
   275  		return nil, err
   276  	}
   277  	defer rk.Close()
   278  	csv, _, err := rk.GetStringValue(searchListKey)
   279  	if err != nil {
   280  		if os.IsNotExist(err) {
   281  			err = nil
   282  		}
   283  		return nil, err
   284  	}
   285  	if csv == "" {
   286  		return nil, nil
   287  	}
   288  	return strings.Split(csv, ","), nil
   289  }
   290  
   291  func (t *nativeDevice) setGlobalSearchList(ctx context.Context, gss []string) error {
   292  	var err error
   293  	if t.strategy == client.GSCAuto || t.strategy == client.GSCRegistry {
   294  		// Try setting the DNS directly in the registry. It's known to work in some situations where powershell fails.
   295  		err = t.setRegistryGlobalSearchList(ctx, gss)
   296  		if err != nil {
   297  			if t.strategy != client.GSCAuto {
   298  				dlog.Errorf(ctx, "setting DNS using the registry value failed: %v", err)
   299  				return err
   300  			}
   301  			dlog.Warnf(ctx, `setting DNS by setting the registry value %s\%s directly failed. Will attempt using powershell`, tcpParamKey, searchListKey)
   302  			t.strategy = client.GSCPowershell
   303  		}
   304  	}
   305  	if t.strategy == client.GSCPowershell {
   306  		cmd := proc.CommandContext(ctx, "powershell.exe", "-NoProfile", "-NonInteractive", "Set-DnsClientGlobalSetting", "-SuffixSearchList", psList(gss))
   307  		if _, err = proc.CaptureErr(cmd); err != nil {
   308  			dlog.Errorf(ctx, "setting DNS using Powershell failed: %v", err)
   309  		}
   310  	}
   311  	if err == nil {
   312  		cmd := proc.CommandContext(ctx, "ipconfig.exe", "/flushdns")
   313  		if _, flushErr := proc.CaptureErr(cmd); flushErr != nil {
   314  			dlog.Errorf(ctx, "flushing DNS cache failed: %v", flushErr)
   315  		}
   316  	}
   317  	return err
   318  }
   319  
   320  func (t *nativeDevice) setRegistryGlobalSearchList(ctx context.Context, gss []string) error {
   321  	// Try setting the DNS directly in the registry. It's known to work in some situations.
   322  	rk, _, err := registry.CreateKey(registry.LOCAL_MACHINE, tcpParamKey, registry.SET_VALUE)
   323  	if err != nil {
   324  		dlog.Errorf(ctx, `creating/opening registry value %s\%s failed: %v`, tcpParamKey, searchListKey, err)
   325  	} else {
   326  		defer rk.Close()
   327  		rv := strings.Join(gss, ",")
   328  		dlog.Debugf(ctx, `setting registry value %s\%s to %s`, tcpParamKey, searchListKey, rv)
   329  		if err = rk.SetStringValue(searchListKey, rv); err != nil {
   330  			dlog.Errorf(ctx, `setting registry value %s\%s failed: %v`, tcpParamKey, searchListKey, err)
   331  		}
   332  	}
   333  	return err
   334  }
   335  
   336  func (t *nativeDevice) setMTU(int) error {
   337  	return errors.New("not implemented")
   338  }
   339  
   340  func (t *nativeDevice) readPacket(into *buffer.Data) (int, error) {
   341  	sz := make([]int, 1)
   342  	packetsN, err := t.Device.Read([][]byte{into.Raw()}, sz, 0)
   343  	if err != nil {
   344  		return 0, err
   345  	}
   346  	if packetsN == 0 {
   347  		return 0, io.EOF
   348  	}
   349  	return sz[0], nil
   350  }
   351  
   352  func (t *nativeDevice) writePacket(from *buffer.Data, offset int) (int, error) {
   353  	packetsN, err := t.Device.Write([][]byte{from.Raw()}, offset)
   354  	if err != nil {
   355  		return 0, err
   356  	}
   357  	if packetsN == 0 {
   358  		return 0, io.EOF
   359  	}
   360  	return len(from.Raw()), nil
   361  }