golang.zx2c4.com/wireguard/windows@v0.5.4-0.20230123132234-dcc0eb72a04b/manager/tunneltracker.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package manager
     7  
     8  import (
     9  	"errors"
    10  	"fmt"
    11  	"log"
    12  	"runtime"
    13  	"sync"
    14  	"sync/atomic"
    15  	"syscall"
    16  	"time"
    17  	"unsafe"
    18  
    19  	"golang.org/x/sys/windows"
    20  	"golang.org/x/sys/windows/svc"
    21  	"golang.org/x/sys/windows/svc/mgr"
    22  
    23  	"golang.zx2c4.com/wireguard/windows/conf"
    24  	"golang.zx2c4.com/wireguard/windows/services"
    25  )
    26  
    27  var (
    28  	trackedTunnels     = make(map[string]TunnelState)
    29  	trackedTunnelsLock = sync.Mutex{}
    30  )
    31  
    32  func trackedTunnelsGlobalState() (state TunnelState) {
    33  	state = TunnelStopped
    34  	trackedTunnelsLock.Lock()
    35  	defer trackedTunnelsLock.Unlock()
    36  	for _, s := range trackedTunnels {
    37  		if s == TunnelStarting {
    38  			return TunnelStarting
    39  		} else if s == TunnelStopping {
    40  			return TunnelStopping
    41  		} else if s == TunnelStarted || s == TunnelUnknown {
    42  			state = TunnelStarted
    43  		}
    44  	}
    45  	return
    46  }
    47  
    48  var serviceTrackerCallbackPtr = windows.NewCallback(func(notifier *windows.SERVICE_NOTIFY) uintptr {
    49  	return 0
    50  })
    51  
    52  type serviceSubscriptionState struct {
    53  	service *mgr.Service
    54  	cb      func(status uint32) bool
    55  	done    sync.WaitGroup
    56  	once    uint32
    57  }
    58  
    59  var serviceSubscriptionCallbackPtr = windows.NewCallback(func(notification uint32, context uintptr) uintptr {
    60  	state := (*serviceSubscriptionState)(unsafe.Pointer(context))
    61  	if atomic.LoadUint32(&state.once) != 0 {
    62  		return 0
    63  	}
    64  	if notification == 0 {
    65  		status, err := state.service.Query()
    66  		if err == nil {
    67  			notification = svcStateToNotifyState(uint32(status.State))
    68  		}
    69  	}
    70  	if state.cb(notification) && atomic.CompareAndSwapUint32(&state.once, 0, 1) {
    71  		state.done.Done()
    72  	}
    73  	return 0
    74  })
    75  
    76  func svcStateToNotifyState(s uint32) uint32 {
    77  	switch s {
    78  	case windows.SERVICE_STOPPED:
    79  		return windows.SERVICE_NOTIFY_STOPPED
    80  	case windows.SERVICE_START_PENDING:
    81  		return windows.SERVICE_NOTIFY_START_PENDING
    82  	case windows.SERVICE_STOP_PENDING:
    83  		return windows.SERVICE_NOTIFY_STOP_PENDING
    84  	case windows.SERVICE_RUNNING:
    85  		return windows.SERVICE_NOTIFY_RUNNING
    86  	case windows.SERVICE_CONTINUE_PENDING:
    87  		return windows.SERVICE_NOTIFY_CONTINUE_PENDING
    88  	case windows.SERVICE_PAUSE_PENDING:
    89  		return windows.SERVICE_NOTIFY_PAUSE_PENDING
    90  	case windows.SERVICE_PAUSED:
    91  		return windows.SERVICE_NOTIFY_PAUSED
    92  	case windows.SERVICE_NO_CHANGE:
    93  		return 0
    94  	default:
    95  		return 0
    96  	}
    97  }
    98  
    99  func notifyStateToTunState(s uint32) TunnelState {
   100  	if s&(windows.SERVICE_NOTIFY_STOPPED|windows.SERVICE_NOTIFY_DELETED) != 0 {
   101  		return TunnelStopped
   102  	} else if s&(windows.SERVICE_NOTIFY_DELETE_PENDING|windows.SERVICE_NOTIFY_STOP_PENDING) != 0 {
   103  		return TunnelStopping
   104  	} else if s&windows.SERVICE_NOTIFY_RUNNING != 0 {
   105  		return TunnelStarted
   106  	} else if s&windows.SERVICE_NOTIFY_START_PENDING != 0 {
   107  		return TunnelStarting
   108  	} else {
   109  		return TunnelUnknown
   110  	}
   111  }
   112  
   113  func trackService(service *mgr.Service, callback func(status uint32) bool) error {
   114  	var subscription uintptr
   115  	state := &serviceSubscriptionState{service: service, cb: callback}
   116  	state.done.Add(1)
   117  	err := windows.SubscribeServiceChangeNotifications(service.Handle, windows.SC_EVENT_STATUS_CHANGE, serviceSubscriptionCallbackPtr, uintptr(unsafe.Pointer(state)), &subscription)
   118  	if err == nil {
   119  		defer windows.UnsubscribeServiceChangeNotifications(subscription)
   120  		status, err := service.Query()
   121  		if err == nil {
   122  			if callback(svcStateToNotifyState(uint32(status.State))) {
   123  				return nil
   124  			}
   125  		}
   126  		state.done.Wait()
   127  		runtime.KeepAlive(state.cb)
   128  		return nil
   129  	}
   130  	if !errors.Is(err, windows.ERROR_PROC_NOT_FOUND) {
   131  		return err
   132  	}
   133  
   134  	// TODO: Below this line is Windows 7 compatibility code, which hopefully we can delete at some point.
   135  
   136  	runtime.LockOSThread()
   137  	// This line would be fitting but is intentionally commented out:
   138  	//
   139  	//     defer runtime.UnlockOSThread()
   140  	//
   141  	// The reason is that NotifyServiceStatusChange used queued APC, which winds up messing
   142  	// with the thread local context, which in turn appears to corrupt Go's own usage of TLS,
   143  	// leading to crashes sometime later (usually in runtime_unlock()) when the thread is recycled.
   144  
   145  	const serviceNotifications = windows.SERVICE_NOTIFY_RUNNING | windows.SERVICE_NOTIFY_START_PENDING | windows.SERVICE_NOTIFY_STOP_PENDING | windows.SERVICE_NOTIFY_STOPPED | windows.SERVICE_NOTIFY_DELETE_PENDING
   146  	notifier := &windows.SERVICE_NOTIFY{
   147  		Version:        windows.SERVICE_NOTIFY_STATUS_CHANGE,
   148  		NotifyCallback: serviceTrackerCallbackPtr,
   149  	}
   150  	for {
   151  		err := windows.NotifyServiceStatusChange(service.Handle, serviceNotifications, notifier)
   152  		switch err {
   153  		case nil:
   154  			for {
   155  				if windows.SleepEx(uint32(time.Second*3/time.Millisecond), true) == windows.WAIT_IO_COMPLETION {
   156  					break
   157  				} else if callback(0) {
   158  					return nil
   159  				}
   160  			}
   161  		case windows.ERROR_SERVICE_MARKED_FOR_DELETE:
   162  			// Should be SERVICE_NOTIFY_DELETE_PENDING, but actually, we must release the handle and return here; otherwise it never deletes.
   163  			if callback(windows.SERVICE_NOTIFY_DELETED) {
   164  				return nil
   165  			}
   166  		case windows.ERROR_SERVICE_NOTIFY_CLIENT_LAGGING:
   167  			continue
   168  		default:
   169  			return err
   170  		}
   171  		if callback(svcStateToNotifyState(notifier.ServiceStatus.CurrentState)) {
   172  			return nil
   173  		}
   174  	}
   175  }
   176  
   177  func trackTunnelService(tunnelName string, service *mgr.Service) {
   178  	trackedTunnelsLock.Lock()
   179  	if _, found := trackedTunnels[tunnelName]; found {
   180  		trackedTunnelsLock.Unlock()
   181  		service.Close()
   182  		return
   183  	}
   184  
   185  	defer func() {
   186  		service.Close()
   187  		log.Printf("[%s] Tunnel service tracker finished", tunnelName)
   188  	}()
   189  	trackedTunnels[tunnelName] = TunnelUnknown
   190  	trackedTunnelsLock.Unlock()
   191  	defer func() {
   192  		trackedTunnelsLock.Lock()
   193  		delete(trackedTunnels, tunnelName)
   194  		trackedTunnelsLock.Unlock()
   195  	}()
   196  
   197  	for i := 0; i < 20; i++ {
   198  		if i > 0 {
   199  			time.Sleep(time.Second / 5)
   200  		}
   201  		if status, err := service.Query(); err != nil || status.State != svc.Stopped {
   202  			break
   203  		}
   204  	}
   205  
   206  	checkForDisabled := func() (shouldReturn bool) {
   207  		config, err := service.Config()
   208  		if err == windows.ERROR_SERVICE_MARKED_FOR_DELETE || (err != nil && config.StartType == windows.SERVICE_DISABLED) {
   209  			log.Printf("[%s] Found disabled service via timeout, so deleting", tunnelName)
   210  			service.Delete()
   211  			trackedTunnelsLock.Lock()
   212  			trackedTunnels[tunnelName] = TunnelStopped
   213  			trackedTunnelsLock.Unlock()
   214  			IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, nil)
   215  			return true
   216  		}
   217  		return false
   218  	}
   219  	if checkForDisabled() {
   220  		return
   221  	}
   222  	lastState := TunnelUnknown
   223  	err := trackService(service, func(status uint32) bool {
   224  		state := notifyStateToTunState(status)
   225  		var tunnelError error
   226  		if state == TunnelStopped {
   227  			serviceStatus, err := service.Query()
   228  			if err == nil {
   229  				if serviceStatus.Win32ExitCode == uint32(windows.ERROR_SERVICE_SPECIFIC_ERROR) {
   230  					maybeErr := services.Error(serviceStatus.ServiceSpecificExitCode)
   231  					if maybeErr != services.ErrorSuccess {
   232  						tunnelError = maybeErr
   233  					}
   234  				} else {
   235  					switch serviceStatus.Win32ExitCode {
   236  					case uint32(windows.NO_ERROR), uint32(windows.ERROR_SERVICE_NEVER_STARTED):
   237  					default:
   238  						tunnelError = syscall.Errno(serviceStatus.Win32ExitCode)
   239  					}
   240  				}
   241  			}
   242  			if tunnelError != nil {
   243  				service.Delete()
   244  			}
   245  		}
   246  		if state != lastState {
   247  			trackedTunnelsLock.Lock()
   248  			trackedTunnels[tunnelName] = state
   249  			trackedTunnelsLock.Unlock()
   250  			IPCServerNotifyTunnelChange(tunnelName, state, tunnelError)
   251  			lastState = state
   252  		}
   253  		if state == TunnelUnknown && checkForDisabled() {
   254  			return true
   255  		}
   256  		return state == TunnelStopped
   257  	})
   258  	if err != nil && !checkForDisabled() {
   259  		trackedTunnelsLock.Lock()
   260  		trackedTunnels[tunnelName] = TunnelStopped
   261  		trackedTunnelsLock.Unlock()
   262  		IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, fmt.Errorf("Unable to continue monitoring service, so stopping: %w", err))
   263  		service.Control(svc.Stop)
   264  	}
   265  }
   266  
   267  func trackExistingTunnels() error {
   268  	m, err := serviceManager()
   269  	if err != nil {
   270  		return err
   271  	}
   272  	names, err := conf.ListConfigNames()
   273  	if err != nil {
   274  		return err
   275  	}
   276  	for _, name := range names {
   277  		trackedTunnelsLock.Lock()
   278  		if _, found := trackedTunnels[name]; found {
   279  			trackedTunnelsLock.Unlock()
   280  			continue
   281  		}
   282  		trackedTunnelsLock.Unlock()
   283  		serviceName, err := conf.ServiceNameOfTunnel(name)
   284  		if err != nil {
   285  			continue
   286  		}
   287  		service, err := m.OpenService(serviceName)
   288  		if err != nil {
   289  			continue
   290  		}
   291  		go trackTunnelService(name, service)
   292  	}
   293  	return nil
   294  }
   295  
   296  var servicesSubscriptionWatcherCallbackPtr = windows.NewCallback(func(notification uint32, context uintptr) uintptr {
   297  	trackExistingTunnels()
   298  	return 0
   299  })
   300  
   301  func watchNewTunnelServices() error {
   302  	m, err := serviceManager()
   303  	if err != nil {
   304  		return err
   305  	}
   306  	var subscription uintptr
   307  	err = windows.SubscribeServiceChangeNotifications(m.Handle, windows.SC_EVENT_DATABASE_CHANGE, servicesSubscriptionWatcherCallbackPtr, 0, &subscription)
   308  	if err == nil {
   309  		// We probably could do:
   310  		//     defer windows.UnsubscribeServiceChangeNotifications(subscription)
   311  		// and then terminate after some point, but instead we just let this go forever; it's process-lived.
   312  		return trackExistingTunnels()
   313  	}
   314  	if !errors.Is(err, windows.ERROR_PROC_NOT_FOUND) {
   315  		return err
   316  	}
   317  
   318  	// TODO: Below this line is Windows 7 compatibility code, which hopefully we can delete at some point.
   319  	go func() {
   320  		runtime.LockOSThread()
   321  		notifier := &windows.SERVICE_NOTIFY{
   322  			Version:        windows.SERVICE_NOTIFY_STATUS_CHANGE,
   323  			NotifyCallback: serviceTrackerCallbackPtr,
   324  		}
   325  		for {
   326  			err := windows.NotifyServiceStatusChange(m.Handle, windows.SERVICE_NOTIFY_CREATED, notifier)
   327  			if err == nil {
   328  				windows.SleepEx(windows.INFINITE, true)
   329  				if notifier.ServiceNames != nil {
   330  					windows.LocalFree(windows.Handle(unsafe.Pointer(notifier.ServiceNames)))
   331  					notifier.ServiceNames = nil
   332  				}
   333  				trackExistingTunnels()
   334  			} else if err == windows.ERROR_SERVICE_NOTIFY_CLIENT_LAGGING {
   335  				continue
   336  			} else {
   337  				time.Sleep(time.Second * 3)
   338  				trackExistingTunnels()
   339  			}
   340  		}
   341  	}()
   342  	return trackExistingTunnels()
   343  }