github.com/telepresenceio/telepresence/v2@v2.20.0-pro.6.0.20240517030216-236ea954e789/pkg/vif/device_windows.go (about) 1 package vif 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "io" 8 "net" 9 "net/netip" 10 "os" 11 "slices" 12 "strings" 13 "time" 14 15 "golang.org/x/sys/windows" 16 "golang.org/x/sys/windows/registry" 17 "golang.zx2c4.com/wireguard/tun" 18 "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" 19 20 "github.com/datawire/dlib/derror" 21 "github.com/datawire/dlib/dlog" 22 "github.com/telepresenceio/telepresence/v2/pkg/client" 23 "github.com/telepresenceio/telepresence/v2/pkg/proc" 24 "github.com/telepresenceio/telepresence/v2/pkg/vif/buffer" 25 ) 26 27 // This nativeDevice will require that wintun.dll is available to the loader. 28 // See: https://www.wintun.net/ for more info. 29 type nativeDevice struct { 30 tun.Device 31 strategy client.GSCStrategy 32 name string 33 dns net.IP 34 interfaceIndex int32 35 searchListAdditions map[string]struct{} 36 } 37 38 func openTun(ctx context.Context) (td *nativeDevice, err error) { 39 defer func() { 40 if r := recover(); r != nil { 41 err = derror.PanicToError(r) 42 dlog.Errorf(ctx, "%+v", err) 43 } 44 }() 45 interfaceFmt := "tel%d" 46 ifaceNumber := 0 47 ifaces, err := net.Interfaces() 48 if err != nil { 49 return nil, fmt.Errorf("failed to get interfaces: %w", err) 50 } 51 for _, iface := range ifaces { 52 dlog.Tracef(ctx, "Found interface %s", iface.Name) 53 // Parse the tel%d number if it's there 54 var num int 55 if _, err := fmt.Sscanf(iface.Name, interfaceFmt, &num); err == nil { 56 if num >= ifaceNumber { 57 ifaceNumber = num + 1 58 } 59 } 60 } 61 interfaceName := fmt.Sprintf(interfaceFmt, ifaceNumber) 62 dlog.Infof(ctx, "Creating interface %s", interfaceName) 63 td = &nativeDevice{ 64 searchListAdditions: make(map[string]struct{}), 65 } 66 if td.Device, err = tun.CreateTUN(interfaceName, 0); err != nil { 67 return nil, fmt.Errorf("failed to create TUN device: %w", err) 68 } 69 if td.name, err = td.Device.Name(); err != nil { 70 return nil, fmt.Errorf("failed to get real name of TUN device: %w", err) 71 } 72 iface, err := td.getLUID().Interface() 73 if err != nil { 74 return nil, fmt.Errorf("failed to get interface for TUN device: %w", err) 75 } 76 td.interfaceIndex = int32(iface.InterfaceIndex) 77 td.strategy = client.GetConfig(ctx).OSSpecific().Network.GlobalDNSSearchConfigStrategy 78 79 return td, nil 80 } 81 82 func (t *nativeDevice) Close() error { 83 // The tun.NativeTun device has a closing mutex which is read locked during 84 // a call to Read(). The read lock prevents a call to Close() to proceed 85 // until Read() actually receives something. To resolve that "deadlock", 86 // we call Close() in one goroutine to wait for the lock and write a bogus 87 // message in another that will be returned by Read(). 88 closeCh := make(chan error) 89 go func() { 90 // first message is just to indicate that this goroutine has started 91 closeCh <- nil 92 closeCh <- t.Device.Close() 93 close(closeCh) 94 }() 95 96 // Not 100%, but we can be fairly sure that Close() is 97 // hanging on the lock, or at least will be by the time 98 // the Read() returns 99 <-closeCh 100 101 // Send something to the TUN device so that the Read 102 // unlocks the NativeTun.closing mutex and let the actual 103 // Close call continue 104 conn, err := net.Dial("udp", net.JoinHostPort(t.dns.String(), "53")) 105 if err == nil { 106 _, _ = conn.Write([]byte("bogus")) 107 } 108 return <-closeCh 109 } 110 111 func (t *nativeDevice) getLUID() winipcfg.LUID { 112 return winipcfg.LUID(t.Device.(*tun.NativeTun).LUID()) 113 } 114 115 func (t *nativeDevice) index() int32 { 116 return t.interfaceIndex 117 } 118 119 func addrFromIP(ip net.IP) netip.Addr { 120 var addr netip.Addr 121 if ip4 := ip.To4(); ip4 != nil { 122 addr = netip.AddrFrom4(*(*[4]byte)(ip4)) 123 } else if ip16 := ip.To16(); ip16 != nil { 124 addr = netip.AddrFrom16(*(*[16]byte)(ip16)) 125 } 126 return addr 127 } 128 129 func prefixFromIPNet(subnet *net.IPNet) netip.Prefix { 130 if subnet == nil { 131 return netip.Prefix{} 132 } 133 ones, _ := subnet.Mask.Size() 134 return netip.PrefixFrom(addrFromIP(subnet.IP), ones) 135 } 136 137 func (t *nativeDevice) addSubnet(_ context.Context, subnet *net.IPNet) error { 138 return t.getLUID().AddIPAddress(prefixFromIPNet(subnet)) 139 } 140 141 func (t *nativeDevice) removeSubnet(_ context.Context, subnet *net.IPNet) error { 142 return t.getLUID().DeleteIPAddress(prefixFromIPNet(subnet)) 143 } 144 145 func (t *nativeDevice) setDNS(ctx context.Context, _ string, server net.IP, searchList []string) (err error) { 146 // This function must not be interrupted by a context cancellation, so we give it a timeout instead. 147 dlog.Debugf(ctx, "SetDNS server: %s, searchList: %v", server, searchList) 148 defer dlog.Debug(ctx, "SetDNS done") 149 150 parentCtx := ctx 151 ctx, cancel := context.WithCancel(context.WithoutCancel(ctx)) 152 defer cancel() 153 154 go func() { 155 <-parentCtx.Done() 156 // Give this function some time to complete its task after the parentCtx is done. Configuring DSN on windows is slow 157 // and we don't want to interrupt it. 158 time.AfterFunc(10*time.Second, cancel) 159 }() 160 161 ipFamily := func(ip net.IP) winipcfg.AddressFamily { 162 f := winipcfg.AddressFamily(windows.AF_INET6) 163 if ip4 := ip.To4(); ip4 != nil { 164 f = windows.AF_INET 165 } 166 return f 167 } 168 family := ipFamily(server) 169 luid := t.getLUID() 170 if t.dns != nil { 171 if oldFamily := ipFamily(t.dns); oldFamily != family { 172 _ = luid.FlushDNS(oldFamily) 173 } 174 } 175 serverStr := server.String() 176 servers16, err := windows.UTF16PtrFromString(serverStr) 177 if err != nil { 178 return err 179 } 180 searchList16, err := windows.UTF16PtrFromString(strings.Join(searchList, ",")) 181 if err != nil { 182 return err 183 } 184 guid, err := luid.GUID() 185 if err != nil { 186 return err 187 } 188 dnsInterfaceSettings := &winipcfg.DnsInterfaceSettings{ 189 Version: winipcfg.DnsInterfaceSettingsVersion1, 190 Flags: winipcfg.DnsInterfaceSettingsFlagNameserver | winipcfg.DnsInterfaceSettingsFlagSearchList, 191 NameServer: servers16, 192 SearchList: searchList16, 193 } 194 if family == windows.AF_INET6 { 195 dnsInterfaceSettings.Flags |= winipcfg.DnsInterfaceSettingsFlagIPv6 196 } 197 if err = winipcfg.SetInterfaceDnsSettings(*guid, dnsInterfaceSettings); err != nil { 198 return err 199 } 200 201 // Unless we also update the global DNS search path, the one for the device doesn't work on some platforms. 202 // This behavior is mainly observed on Windows Server editions. 203 204 // Retrieve the current global search paths so that paths that aren't managed by us can be retained. 205 gss, err := getGlobalSearchList() 206 if err != nil { 207 return err 208 } 209 // Put our new search path in front of other entries. 210 uniq := make(map[string]int, len(searchList)+len(gss)) 211 i := 0 212 for _, gs := range searchList { 213 gs = strings.TrimSuffix(gs, ".") 214 t.searchListAdditions[gs] = struct{}{} 215 if _, ok := uniq[gs]; !ok { 216 uniq[gs] = i 217 i++ 218 } 219 } 220 221 // Include entries that aren't managed by Telepresence. 222 for _, gs := range gss { 223 if _, ok := t.searchListAdditions[gs]; !ok { 224 if _, ok := uniq[gs]; !ok { 225 uniq[gs] = i 226 i++ 227 } 228 } 229 } 230 231 gss = make([]string, len(uniq)) 232 for gs, i := range uniq { 233 gss[i] = gs 234 } 235 t.dns = server 236 if err := t.setGlobalSearchList(ctx, gss); err != nil { 237 return err 238 } 239 240 // Prune the list of additions using the current search path. 241 for gs := range t.searchListAdditions { 242 if !slices.Contains(gss, gs) { 243 delete(t.searchListAdditions, gs) 244 } 245 } 246 return nil 247 } 248 249 func psList(values []string) string { 250 var sb strings.Builder 251 sb.WriteString("@(") 252 for i, gs := range values { 253 if i > 0 { 254 sb.WriteByte(',') 255 } 256 sb.WriteByte('"') 257 sb.WriteString(gs) 258 sb.WriteByte('"') 259 } 260 sb.WriteByte(')') 261 return sb.String() 262 } 263 264 const ( 265 tcpParamKey = `System\CurrentControlSet\Services\Tcpip\Parameters` 266 searchListKey = `SearchList` 267 ) 268 269 func getGlobalSearchList() ([]string, error) { 270 rk, err := registry.OpenKey(registry.LOCAL_MACHINE, tcpParamKey, registry.QUERY_VALUE) 271 if err != nil { 272 if os.IsNotExist(err) { 273 err = nil 274 } 275 return nil, err 276 } 277 defer rk.Close() 278 csv, _, err := rk.GetStringValue(searchListKey) 279 if err != nil { 280 if os.IsNotExist(err) { 281 err = nil 282 } 283 return nil, err 284 } 285 if csv == "" { 286 return nil, nil 287 } 288 return strings.Split(csv, ","), nil 289 } 290 291 func (t *nativeDevice) setGlobalSearchList(ctx context.Context, gss []string) error { 292 var err error 293 if t.strategy == client.GSCAuto || t.strategy == client.GSCRegistry { 294 // Try setting the DNS directly in the registry. It's known to work in some situations where powershell fails. 295 err = t.setRegistryGlobalSearchList(ctx, gss) 296 if err != nil { 297 if t.strategy != client.GSCAuto { 298 dlog.Errorf(ctx, "setting DNS using the registry value failed: %v", err) 299 return err 300 } 301 dlog.Warnf(ctx, `setting DNS by setting the registry value %s\%s directly failed. Will attempt using powershell`, tcpParamKey, searchListKey) 302 t.strategy = client.GSCPowershell 303 } 304 } 305 if t.strategy == client.GSCPowershell { 306 cmd := proc.CommandContext(ctx, "powershell.exe", "-NoProfile", "-NonInteractive", "Set-DnsClientGlobalSetting", "-SuffixSearchList", psList(gss)) 307 if _, err = proc.CaptureErr(cmd); err != nil { 308 dlog.Errorf(ctx, "setting DNS using Powershell failed: %v", err) 309 } 310 } 311 if err == nil { 312 cmd := proc.CommandContext(ctx, "ipconfig.exe", "/flushdns") 313 if _, flushErr := proc.CaptureErr(cmd); flushErr != nil { 314 dlog.Errorf(ctx, "flushing DNS cache failed: %v", flushErr) 315 } 316 } 317 return err 318 } 319 320 func (t *nativeDevice) setRegistryGlobalSearchList(ctx context.Context, gss []string) error { 321 // Try setting the DNS directly in the registry. It's known to work in some situations. 322 rk, _, err := registry.CreateKey(registry.LOCAL_MACHINE, tcpParamKey, registry.SET_VALUE) 323 if err != nil { 324 dlog.Errorf(ctx, `creating/opening registry value %s\%s failed: %v`, tcpParamKey, searchListKey, err) 325 } else { 326 defer rk.Close() 327 rv := strings.Join(gss, ",") 328 dlog.Debugf(ctx, `setting registry value %s\%s to %s`, tcpParamKey, searchListKey, rv) 329 if err = rk.SetStringValue(searchListKey, rv); err != nil { 330 dlog.Errorf(ctx, `setting registry value %s\%s failed: %v`, tcpParamKey, searchListKey, err) 331 } 332 } 333 return err 334 } 335 336 func (t *nativeDevice) setMTU(int) error { 337 return errors.New("not implemented") 338 } 339 340 func (t *nativeDevice) readPacket(into *buffer.Data) (int, error) { 341 sz := make([]int, 1) 342 packetsN, err := t.Device.Read([][]byte{into.Raw()}, sz, 0) 343 if err != nil { 344 return 0, err 345 } 346 if packetsN == 0 { 347 return 0, io.EOF 348 } 349 return sz[0], nil 350 } 351 352 func (t *nativeDevice) writePacket(from *buffer.Data, offset int) (int, error) { 353 packetsN, err := t.Device.Write([][]byte{from.Raw()}, offset) 354 if err != nil { 355 return 0, err 356 } 357 if packetsN == 0 { 358 return 0, io.EOF 359 } 360 return len(from.Raw()), nil 361 }