golang.zx2c4.com/wireguard/windows@v0.5.4-0.20230123132234-dcc0eb72a04b/tunnel/interfacewatcher.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package tunnel
     7  
     8  import (
     9  	"errors"
    10  	"fmt"
    11  	"log"
    12  	"sync"
    13  	"time"
    14  
    15  	"golang.org/x/sys/windows"
    16  	"golang.zx2c4.com/wireguard/windows/conf"
    17  	"golang.zx2c4.com/wireguard/windows/driver"
    18  	"golang.zx2c4.com/wireguard/windows/services"
    19  	"golang.zx2c4.com/wireguard/windows/tunnel/firewall"
    20  	"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
    21  )
    22  
    23  type interfaceWatcherError struct {
    24  	serviceError services.Error
    25  	err          error
    26  }
    27  
    28  type interfaceWatcherEvent struct {
    29  	luid   winipcfg.LUID
    30  	family winipcfg.AddressFamily
    31  }
    32  
    33  type interfaceWatcher struct {
    34  	errors  chan interfaceWatcherError
    35  	started chan winipcfg.AddressFamily
    36  
    37  	conf    *conf.Config
    38  	adapter *driver.Adapter
    39  	luid    winipcfg.LUID
    40  
    41  	setupMutex              sync.Mutex
    42  	interfaceChangeCallback winipcfg.ChangeCallback
    43  	changeCallbacks4        []winipcfg.ChangeCallback
    44  	changeCallbacks6        []winipcfg.ChangeCallback
    45  	storedEvents            []interfaceWatcherEvent
    46  	watchdog                *time.Timer
    47  }
    48  
    49  func (iw *interfaceWatcher) setup(family winipcfg.AddressFamily) {
    50  	iw.watchdog.Stop()
    51  	var changeCallbacks *[]winipcfg.ChangeCallback
    52  	var ipversion string
    53  	if family == windows.AF_INET {
    54  		changeCallbacks = &iw.changeCallbacks4
    55  		ipversion = "v4"
    56  	} else if family == windows.AF_INET6 {
    57  		changeCallbacks = &iw.changeCallbacks6
    58  		ipversion = "v6"
    59  	} else {
    60  		return
    61  	}
    62  	if len(*changeCallbacks) != 0 {
    63  		for _, cb := range *changeCallbacks {
    64  			cb.Unregister()
    65  		}
    66  		*changeCallbacks = nil
    67  	}
    68  	var err error
    69  
    70  	if iw.conf.Interface.MTU == 0 {
    71  		log.Printf("Monitoring MTU of default %s routes", ipversion)
    72  		*changeCallbacks, err = monitorMTU(family, iw.luid)
    73  		if err != nil {
    74  			iw.errors <- interfaceWatcherError{services.ErrorMonitorMTUChanges, err}
    75  			return
    76  		}
    77  	}
    78  
    79  	log.Printf("Setting device %s addresses", ipversion)
    80  	err = configureInterface(family, iw.conf, iw.luid)
    81  	if err != nil {
    82  		iw.errors <- interfaceWatcherError{services.ErrorSetNetConfig, err}
    83  		return
    84  	}
    85  	evaluateDynamicPitfalls(family, iw.conf, iw.luid)
    86  
    87  	iw.started <- family
    88  }
    89  
    90  func watchInterface() (*interfaceWatcher, error) {
    91  	iw := &interfaceWatcher{
    92  		errors:  make(chan interfaceWatcherError, 2),
    93  		started: make(chan winipcfg.AddressFamily, 4),
    94  	}
    95  	iw.watchdog = time.AfterFunc(time.Duration(1<<63-1), func() {
    96  		iw.errors <- interfaceWatcherError{services.ErrorCreateNetworkAdapter, errors.New("TCP/IP interface for adapter did not appear after one minute")}
    97  	})
    98  	iw.watchdog.Stop()
    99  	var err error
   100  	iw.interfaceChangeCallback, err = winipcfg.RegisterInterfaceChangeCallback(func(notificationType winipcfg.MibNotificationType, iface *winipcfg.MibIPInterfaceRow) {
   101  		iw.setupMutex.Lock()
   102  		defer iw.setupMutex.Unlock()
   103  
   104  		if notificationType != winipcfg.MibAddInstance {
   105  			return
   106  		}
   107  		if iw.luid == 0 {
   108  			iw.storedEvents = append(iw.storedEvents, interfaceWatcherEvent{iface.InterfaceLUID, iface.Family})
   109  			return
   110  		}
   111  		if iface.InterfaceLUID != iw.luid {
   112  			return
   113  		}
   114  		iw.setup(iface.Family)
   115  
   116  		if state, err := iw.adapter.AdapterState(); err == nil && state == driver.AdapterStateDown {
   117  			log.Println("Reinitializing adapter configuration")
   118  			err = iw.adapter.SetConfiguration(iw.conf.ToDriverConfiguration())
   119  			if err != nil {
   120  				log.Println(fmt.Errorf("%v: %w", services.ErrorDeviceSetConfig, err))
   121  			}
   122  			err = iw.adapter.SetAdapterState(driver.AdapterStateUp)
   123  			if err != nil {
   124  				log.Println(fmt.Errorf("%v: %w", services.ErrorDeviceBringUp, err))
   125  			}
   126  		}
   127  	})
   128  	if err != nil {
   129  		return nil, fmt.Errorf("unable to register interface change callback: %w", err)
   130  	}
   131  	return iw, nil
   132  }
   133  
   134  func (iw *interfaceWatcher) Configure(adapter *driver.Adapter, conf *conf.Config, luid winipcfg.LUID) {
   135  	iw.setupMutex.Lock()
   136  	defer iw.setupMutex.Unlock()
   137  	iw.watchdog.Reset(time.Minute)
   138  
   139  	iw.adapter, iw.conf, iw.luid = adapter, conf, luid
   140  	for _, event := range iw.storedEvents {
   141  		if event.luid == luid {
   142  			iw.setup(event.family)
   143  		}
   144  	}
   145  	iw.storedEvents = nil
   146  }
   147  
   148  func (iw *interfaceWatcher) Destroy() {
   149  	iw.setupMutex.Lock()
   150  	iw.watchdog.Stop()
   151  	changeCallbacks4 := iw.changeCallbacks4
   152  	changeCallbacks6 := iw.changeCallbacks6
   153  	interfaceChangeCallback := iw.interfaceChangeCallback
   154  	luid := iw.luid
   155  	iw.setupMutex.Unlock()
   156  
   157  	if interfaceChangeCallback != nil {
   158  		interfaceChangeCallback.Unregister()
   159  	}
   160  	for _, cb := range changeCallbacks4 {
   161  		cb.Unregister()
   162  	}
   163  	for _, cb := range changeCallbacks6 {
   164  		cb.Unregister()
   165  	}
   166  
   167  	iw.setupMutex.Lock()
   168  	if interfaceChangeCallback == iw.interfaceChangeCallback {
   169  		iw.interfaceChangeCallback = nil
   170  	}
   171  	for len(changeCallbacks4) > 0 && len(iw.changeCallbacks4) > 0 {
   172  		iw.changeCallbacks4 = iw.changeCallbacks4[1:]
   173  		changeCallbacks4 = changeCallbacks4[1:]
   174  	}
   175  	for len(changeCallbacks6) > 0 && len(iw.changeCallbacks6) > 0 {
   176  		iw.changeCallbacks6 = iw.changeCallbacks6[1:]
   177  		changeCallbacks6 = changeCallbacks6[1:]
   178  	}
   179  	firewall.DisableFirewall()
   180  	if luid != 0 && iw.luid == luid {
   181  		// It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active
   182  		// routes, so to be certain, just remove everything before destroying.
   183  		luid.FlushRoutes(windows.AF_INET)
   184  		luid.FlushIPAddresses(windows.AF_INET)
   185  		luid.FlushDNS(windows.AF_INET)
   186  		luid.FlushRoutes(windows.AF_INET6)
   187  		luid.FlushIPAddresses(windows.AF_INET6)
   188  		luid.FlushDNS(windows.AF_INET6)
   189  	}
   190  	iw.setupMutex.Unlock()
   191  }