golang.zx2c4.com/wireguard/windows@v0.5.4-0.20230123132234-dcc0eb72a04b/tunnel/addressconfig.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package tunnel
     7  
     8  import (
     9  	"fmt"
    10  	"log"
    11  	"net/netip"
    12  	"time"
    13  
    14  	"golang.org/x/sys/windows"
    15  	"golang.zx2c4.com/wireguard/windows/conf"
    16  	"golang.zx2c4.com/wireguard/windows/services"
    17  	"golang.zx2c4.com/wireguard/windows/tunnel/firewall"
    18  	"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
    19  )
    20  
    21  func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, addresses []netip.Prefix) {
    22  	if len(addresses) == 0 {
    23  		return
    24  	}
    25  	addrHash := make(map[netip.Addr]bool, len(addresses))
    26  	for i := range addresses {
    27  		addrHash[addresses[i].Addr()] = true
    28  	}
    29  	interfaces, err := winipcfg.GetAdaptersAddresses(family, winipcfg.GAAFlagDefault)
    30  	if err != nil {
    31  		return
    32  	}
    33  	for _, iface := range interfaces {
    34  		if iface.OperStatus == winipcfg.IfOperStatusUp {
    35  			continue
    36  		}
    37  		for address := iface.FirstUnicastAddress; address != nil; address = address.Next {
    38  			if ip, _ := netip.AddrFromSlice(address.Address.IP()); addrHash[ip] {
    39  				prefix := netip.PrefixFrom(ip, int(address.OnLinkPrefixLength))
    40  				log.Printf("Cleaning up stale address %s from interface ā€˜%s’", prefix.String(), iface.FriendlyName())
    41  				iface.LUID.DeleteIPAddress(prefix)
    42  			}
    43  		}
    44  	}
    45  }
    46  
    47  func configureInterface(family winipcfg.AddressFamily, conf *conf.Config, luid winipcfg.LUID) error {
    48  	retryOnFailure := services.StartedAtBoot()
    49  	tryTimes := 0
    50  startOver:
    51  	var err error
    52  	if tryTimes > 0 {
    53  		log.Printf("Retrying interface configuration after failure because system just booted (T+%v): %v", windows.DurationSinceBoot(), err)
    54  		time.Sleep(time.Second)
    55  		retryOnFailure = retryOnFailure && tryTimes < 15
    56  	}
    57  	tryTimes++
    58  
    59  	estimatedRouteCount := 0
    60  	for _, peer := range conf.Peers {
    61  		estimatedRouteCount += len(peer.AllowedIPs)
    62  	}
    63  	routes := make(map[winipcfg.RouteData]bool, estimatedRouteCount)
    64  
    65  	foundDefault4 := false
    66  	foundDefault6 := false
    67  	for _, peer := range conf.Peers {
    68  		for _, allowedip := range peer.AllowedIPs {
    69  			route := winipcfg.RouteData{
    70  				Destination: allowedip.Masked(),
    71  				Metric:      0,
    72  			}
    73  			if allowedip.Addr().Is4() {
    74  				if allowedip.Bits() == 0 {
    75  					foundDefault4 = true
    76  				}
    77  				route.NextHop = netip.IPv4Unspecified()
    78  			} else if allowedip.Addr().Is6() {
    79  				if allowedip.Bits() == 0 {
    80  					foundDefault6 = true
    81  				}
    82  				route.NextHop = netip.IPv6Unspecified()
    83  			}
    84  			routes[route] = true
    85  		}
    86  	}
    87  
    88  	deduplicatedRoutes := make([]*winipcfg.RouteData, 0, len(routes))
    89  	for route := range routes {
    90  		r := route
    91  		deduplicatedRoutes = append(deduplicatedRoutes, &r)
    92  	}
    93  
    94  	if !conf.Interface.TableOff {
    95  		err = luid.SetRoutesForFamily(family, deduplicatedRoutes)
    96  		if err == windows.ERROR_NOT_FOUND && retryOnFailure {
    97  			goto startOver
    98  		} else if err != nil {
    99  			return fmt.Errorf("unable to set routes: %w", err)
   100  		}
   101  	}
   102  
   103  	err = luid.SetIPAddressesForFamily(family, conf.Interface.Addresses)
   104  	if err == windows.ERROR_OBJECT_ALREADY_EXISTS {
   105  		cleanupAddressesOnDisconnectedInterfaces(family, conf.Interface.Addresses)
   106  		err = luid.SetIPAddressesForFamily(family, conf.Interface.Addresses)
   107  	}
   108  	if err == windows.ERROR_NOT_FOUND && retryOnFailure {
   109  		goto startOver
   110  	} else if err != nil {
   111  		return fmt.Errorf("unable to set ips: %w", err)
   112  	}
   113  
   114  	var ipif *winipcfg.MibIPInterfaceRow
   115  	ipif, err = luid.IPInterface(family)
   116  	if err != nil {
   117  		return err
   118  	}
   119  	ipif.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled
   120  	ipif.DadTransmits = 0
   121  	ipif.ManagedAddressConfigurationSupported = false
   122  	ipif.OtherStatefulConfigurationSupported = false
   123  	if conf.Interface.MTU > 0 {
   124  		ipif.NLMTU = uint32(conf.Interface.MTU)
   125  	}
   126  	if (family == windows.AF_INET && foundDefault4) || (family == windows.AF_INET6 && foundDefault6) {
   127  		ipif.UseAutomaticMetric = false
   128  		ipif.Metric = 0
   129  	}
   130  	err = ipif.Set()
   131  	if err == windows.ERROR_NOT_FOUND && retryOnFailure {
   132  		goto startOver
   133  	} else if err != nil {
   134  		return fmt.Errorf("unable to set metric and MTU: %w", err)
   135  	}
   136  
   137  	err = luid.SetDNS(family, conf.Interface.DNS, conf.Interface.DNSSearch)
   138  	if err == windows.ERROR_NOT_FOUND && retryOnFailure {
   139  		goto startOver
   140  	} else if err != nil {
   141  		return fmt.Errorf("unable to set DNS: %w", err)
   142  	}
   143  	return nil
   144  }
   145  
   146  func enableFirewall(conf *conf.Config, luid winipcfg.LUID) error {
   147  	doNotRestrict := true
   148  	if len(conf.Peers) == 1 && !conf.Interface.TableOff {
   149  		for _, allowedip := range conf.Peers[0].AllowedIPs {
   150  			if allowedip.Bits() == 0 && allowedip == allowedip.Masked() {
   151  				doNotRestrict = false
   152  				break
   153  			}
   154  		}
   155  	}
   156  	log.Println("Enabling firewall rules")
   157  	return firewall.EnableFirewall(uint64(luid), doNotRestrict, conf.Interface.DNS)
   158  }