golang.zx2c4.com/wireguard/windows@v0.5.4-0.20230123132234-dcc0eb72a04b/tunnel/pitfalls.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package tunnel
     7  
     8  import (
     9  	"log"
    10  	"net/netip"
    11  	"strings"
    12  	"unsafe"
    13  
    14  	"golang.org/x/sys/windows"
    15  	"golang.org/x/sys/windows/svc/mgr"
    16  	"golang.zx2c4.com/wireguard/windows/conf"
    17  	"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
    18  )
    19  
    20  func evaluateStaticPitfalls() {
    21  	go func() {
    22  		pitfallDnsCacheDisabled()
    23  		pitfallVirtioNetworkDriver()
    24  	}()
    25  }
    26  
    27  func evaluateDynamicPitfalls(family winipcfg.AddressFamily, conf *conf.Config, luid winipcfg.LUID) {
    28  	go func() {
    29  		pitfallWeakHostSend(family, conf, luid)
    30  	}()
    31  }
    32  
    33  func pitfallDnsCacheDisabled() {
    34  	scm, err := mgr.Connect()
    35  	if err != nil {
    36  		return
    37  	}
    38  	defer scm.Disconnect()
    39  	svc := mgr.Service{Name: "dnscache"}
    40  	svc.Handle, err = windows.OpenService(scm.Handle, windows.StringToUTF16Ptr(svc.Name), windows.SERVICE_QUERY_CONFIG)
    41  	if err != nil {
    42  		return
    43  	}
    44  	defer svc.Close()
    45  	cfg, err := svc.Config()
    46  	if err != nil {
    47  		return
    48  	}
    49  	if cfg.StartType != mgr.StartDisabled {
    50  		return
    51  	}
    52  
    53  	log.Printf("Warning: the %q (dnscache) service is disabled; please re-enable it", cfg.DisplayName)
    54  }
    55  
    56  func pitfallVirtioNetworkDriver() {
    57  	var modules []windows.RTL_PROCESS_MODULE_INFORMATION
    58  	for bufferSize := uint32(128 * 1024); ; {
    59  		moduleBuffer := make([]byte, bufferSize)
    60  		err := windows.NtQuerySystemInformation(windows.SystemModuleInformation, unsafe.Pointer(&moduleBuffer[0]), bufferSize, &bufferSize)
    61  		switch err {
    62  		case windows.STATUS_INFO_LENGTH_MISMATCH:
    63  			continue
    64  		case nil:
    65  			break
    66  		default:
    67  			return
    68  		}
    69  		mods := (*windows.RTL_PROCESS_MODULES)(unsafe.Pointer(&moduleBuffer[0]))
    70  		modules = unsafe.Slice(&mods.Modules[0], mods.NumberOfModules)
    71  		break
    72  	}
    73  	for i := range modules {
    74  		if !strings.EqualFold(windows.ByteSliceToString(modules[i].FullPathName[modules[i].OffsetToFileName:]), "netkvm.sys") {
    75  			continue
    76  		}
    77  		driverPath := `\\?\GLOBALROOT` + windows.ByteSliceToString(modules[i].FullPathName[:])
    78  		var zero windows.Handle
    79  		infoSize, err := windows.GetFileVersionInfoSize(driverPath, &zero)
    80  		if err != nil {
    81  			return
    82  		}
    83  		versionInfo := make([]byte, infoSize)
    84  		err = windows.GetFileVersionInfo(driverPath, 0, infoSize, unsafe.Pointer(&versionInfo[0]))
    85  		if err != nil {
    86  			return
    87  		}
    88  		var fixedInfo *windows.VS_FIXEDFILEINFO
    89  		fixedInfoLen := uint32(unsafe.Sizeof(*fixedInfo))
    90  		err = windows.VerQueryValue(unsafe.Pointer(&versionInfo[0]), `\`, unsafe.Pointer(&fixedInfo), &fixedInfoLen)
    91  		if err != nil {
    92  			return
    93  		}
    94  		const minimumPlausibleVersion = 40 << 48
    95  		const minimumGoodVersion = (100 << 48) | (85 << 32) | (104 << 16) | (20800 << 0)
    96  		version := (uint64(fixedInfo.FileVersionMS) << 32) | uint64(fixedInfo.FileVersionLS)
    97  		if version >= minimumGoodVersion || version < minimumPlausibleVersion {
    98  			return
    99  		}
   100  		log.Println("Warning: the VirtIO network driver (NetKVM) is out of date and may cause known problems; please update to v100.85.104.20800 or later")
   101  		return
   102  	}
   103  }
   104  
   105  func pitfallWeakHostSend(family winipcfg.AddressFamily, conf *conf.Config, ourLUID winipcfg.LUID) {
   106  	routingTable, err := winipcfg.GetIPForwardTable2(family)
   107  	if err != nil {
   108  		return
   109  	}
   110  	type endpointRoute struct {
   111  		addr         netip.Addr
   112  		name         string
   113  		lowestMetric uint32
   114  		highestCIDR  uint8
   115  		weakHostSend bool
   116  		finalIsOurs  bool
   117  	}
   118  	endpoints := make([]endpointRoute, 0, len(conf.Peers))
   119  	for _, peer := range conf.Peers {
   120  		addr, err := netip.ParseAddr(peer.Endpoint.Host)
   121  		if err != nil || (addr.Is4() && family != windows.AF_INET) || (addr.Is6() && family != windows.AF_INET6) {
   122  			continue
   123  		}
   124  		endpoints = append(endpoints, endpointRoute{addr: addr, lowestMetric: ^uint32(0)})
   125  	}
   126  	for i := range routingTable {
   127  		var (
   128  			ifrow    *winipcfg.MibIfRow2
   129  			ifacerow *winipcfg.MibIPInterfaceRow
   130  			metric   uint32
   131  		)
   132  		for j := range endpoints {
   133  			r, e := &routingTable[i], &endpoints[j]
   134  			if r.DestinationPrefix.PrefixLength < e.highestCIDR {
   135  				continue
   136  			}
   137  			if !r.DestinationPrefix.Prefix().Contains(e.addr) {
   138  				continue
   139  			}
   140  			if ifrow == nil {
   141  				ifrow, err = r.InterfaceLUID.Interface()
   142  				if err != nil {
   143  					continue
   144  				}
   145  			}
   146  			if ifrow.OperStatus != winipcfg.IfOperStatusUp {
   147  				continue
   148  			}
   149  			if ifacerow == nil {
   150  				ifacerow, err = r.InterfaceLUID.IPInterface(family)
   151  				if err != nil {
   152  					continue
   153  				}
   154  				metric = r.Metric + ifacerow.Metric
   155  			}
   156  			if r.DestinationPrefix.PrefixLength == e.highestCIDR && metric > e.lowestMetric {
   157  				continue
   158  			}
   159  			e.lowestMetric = metric
   160  			e.highestCIDR = r.DestinationPrefix.PrefixLength
   161  			e.finalIsOurs = r.InterfaceLUID == ourLUID
   162  			if !e.finalIsOurs {
   163  				e.name = ifrow.Alias()
   164  				e.weakHostSend = ifacerow.ForwardingEnabled || ifacerow.WeakHostSend
   165  			}
   166  		}
   167  	}
   168  	problematicInterfaces := make(map[string]bool, len(endpoints))
   169  	for _, e := range endpoints {
   170  		if e.weakHostSend && e.finalIsOurs {
   171  			problematicInterfaces[e.name] = true
   172  		}
   173  	}
   174  	for iface := range problematicInterfaces {
   175  		log.Printf("Warning: the %q interface has Forwarding/WeakHostSend enabled, which will cause routing loops", iface)
   176  	}
   177  }