golang.zx2c4.com/wireguard/windows@v0.5.4-0.20230123132234-dcc0eb72a04b/manager/tunneltracker.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package manager 7 8 import ( 9 "errors" 10 "fmt" 11 "log" 12 "runtime" 13 "sync" 14 "sync/atomic" 15 "syscall" 16 "time" 17 "unsafe" 18 19 "golang.org/x/sys/windows" 20 "golang.org/x/sys/windows/svc" 21 "golang.org/x/sys/windows/svc/mgr" 22 23 "golang.zx2c4.com/wireguard/windows/conf" 24 "golang.zx2c4.com/wireguard/windows/services" 25 ) 26 27 var ( 28 trackedTunnels = make(map[string]TunnelState) 29 trackedTunnelsLock = sync.Mutex{} 30 ) 31 32 func trackedTunnelsGlobalState() (state TunnelState) { 33 state = TunnelStopped 34 trackedTunnelsLock.Lock() 35 defer trackedTunnelsLock.Unlock() 36 for _, s := range trackedTunnels { 37 if s == TunnelStarting { 38 return TunnelStarting 39 } else if s == TunnelStopping { 40 return TunnelStopping 41 } else if s == TunnelStarted || s == TunnelUnknown { 42 state = TunnelStarted 43 } 44 } 45 return 46 } 47 48 var serviceTrackerCallbackPtr = windows.NewCallback(func(notifier *windows.SERVICE_NOTIFY) uintptr { 49 return 0 50 }) 51 52 type serviceSubscriptionState struct { 53 service *mgr.Service 54 cb func(status uint32) bool 55 done sync.WaitGroup 56 once uint32 57 } 58 59 var serviceSubscriptionCallbackPtr = windows.NewCallback(func(notification uint32, context uintptr) uintptr { 60 state := (*serviceSubscriptionState)(unsafe.Pointer(context)) 61 if atomic.LoadUint32(&state.once) != 0 { 62 return 0 63 } 64 if notification == 0 { 65 status, err := state.service.Query() 66 if err == nil { 67 notification = svcStateToNotifyState(uint32(status.State)) 68 } 69 } 70 if state.cb(notification) && atomic.CompareAndSwapUint32(&state.once, 0, 1) { 71 state.done.Done() 72 } 73 return 0 74 }) 75 76 func svcStateToNotifyState(s uint32) uint32 { 77 switch s { 78 case windows.SERVICE_STOPPED: 79 return windows.SERVICE_NOTIFY_STOPPED 80 case windows.SERVICE_START_PENDING: 81 return windows.SERVICE_NOTIFY_START_PENDING 82 case windows.SERVICE_STOP_PENDING: 83 return windows.SERVICE_NOTIFY_STOP_PENDING 84 case windows.SERVICE_RUNNING: 85 return windows.SERVICE_NOTIFY_RUNNING 86 case windows.SERVICE_CONTINUE_PENDING: 87 return windows.SERVICE_NOTIFY_CONTINUE_PENDING 88 case windows.SERVICE_PAUSE_PENDING: 89 return windows.SERVICE_NOTIFY_PAUSE_PENDING 90 case windows.SERVICE_PAUSED: 91 return windows.SERVICE_NOTIFY_PAUSED 92 case windows.SERVICE_NO_CHANGE: 93 return 0 94 default: 95 return 0 96 } 97 } 98 99 func notifyStateToTunState(s uint32) TunnelState { 100 if s&(windows.SERVICE_NOTIFY_STOPPED|windows.SERVICE_NOTIFY_DELETED) != 0 { 101 return TunnelStopped 102 } else if s&(windows.SERVICE_NOTIFY_DELETE_PENDING|windows.SERVICE_NOTIFY_STOP_PENDING) != 0 { 103 return TunnelStopping 104 } else if s&windows.SERVICE_NOTIFY_RUNNING != 0 { 105 return TunnelStarted 106 } else if s&windows.SERVICE_NOTIFY_START_PENDING != 0 { 107 return TunnelStarting 108 } else { 109 return TunnelUnknown 110 } 111 } 112 113 func trackService(service *mgr.Service, callback func(status uint32) bool) error { 114 var subscription uintptr 115 state := &serviceSubscriptionState{service: service, cb: callback} 116 state.done.Add(1) 117 err := windows.SubscribeServiceChangeNotifications(service.Handle, windows.SC_EVENT_STATUS_CHANGE, serviceSubscriptionCallbackPtr, uintptr(unsafe.Pointer(state)), &subscription) 118 if err == nil { 119 defer windows.UnsubscribeServiceChangeNotifications(subscription) 120 status, err := service.Query() 121 if err == nil { 122 if callback(svcStateToNotifyState(uint32(status.State))) { 123 return nil 124 } 125 } 126 state.done.Wait() 127 runtime.KeepAlive(state.cb) 128 return nil 129 } 130 if !errors.Is(err, windows.ERROR_PROC_NOT_FOUND) { 131 return err 132 } 133 134 // TODO: Below this line is Windows 7 compatibility code, which hopefully we can delete at some point. 135 136 runtime.LockOSThread() 137 // This line would be fitting but is intentionally commented out: 138 // 139 // defer runtime.UnlockOSThread() 140 // 141 // The reason is that NotifyServiceStatusChange used queued APC, which winds up messing 142 // with the thread local context, which in turn appears to corrupt Go's own usage of TLS, 143 // leading to crashes sometime later (usually in runtime_unlock()) when the thread is recycled. 144 145 const serviceNotifications = windows.SERVICE_NOTIFY_RUNNING | windows.SERVICE_NOTIFY_START_PENDING | windows.SERVICE_NOTIFY_STOP_PENDING | windows.SERVICE_NOTIFY_STOPPED | windows.SERVICE_NOTIFY_DELETE_PENDING 146 notifier := &windows.SERVICE_NOTIFY{ 147 Version: windows.SERVICE_NOTIFY_STATUS_CHANGE, 148 NotifyCallback: serviceTrackerCallbackPtr, 149 } 150 for { 151 err := windows.NotifyServiceStatusChange(service.Handle, serviceNotifications, notifier) 152 switch err { 153 case nil: 154 for { 155 if windows.SleepEx(uint32(time.Second*3/time.Millisecond), true) == windows.WAIT_IO_COMPLETION { 156 break 157 } else if callback(0) { 158 return nil 159 } 160 } 161 case windows.ERROR_SERVICE_MARKED_FOR_DELETE: 162 // Should be SERVICE_NOTIFY_DELETE_PENDING, but actually, we must release the handle and return here; otherwise it never deletes. 163 if callback(windows.SERVICE_NOTIFY_DELETED) { 164 return nil 165 } 166 case windows.ERROR_SERVICE_NOTIFY_CLIENT_LAGGING: 167 continue 168 default: 169 return err 170 } 171 if callback(svcStateToNotifyState(notifier.ServiceStatus.CurrentState)) { 172 return nil 173 } 174 } 175 } 176 177 func trackTunnelService(tunnelName string, service *mgr.Service) { 178 trackedTunnelsLock.Lock() 179 if _, found := trackedTunnels[tunnelName]; found { 180 trackedTunnelsLock.Unlock() 181 service.Close() 182 return 183 } 184 185 defer func() { 186 service.Close() 187 log.Printf("[%s] Tunnel service tracker finished", tunnelName) 188 }() 189 trackedTunnels[tunnelName] = TunnelUnknown 190 trackedTunnelsLock.Unlock() 191 defer func() { 192 trackedTunnelsLock.Lock() 193 delete(trackedTunnels, tunnelName) 194 trackedTunnelsLock.Unlock() 195 }() 196 197 for i := 0; i < 20; i++ { 198 if i > 0 { 199 time.Sleep(time.Second / 5) 200 } 201 if status, err := service.Query(); err != nil || status.State != svc.Stopped { 202 break 203 } 204 } 205 206 checkForDisabled := func() (shouldReturn bool) { 207 config, err := service.Config() 208 if err == windows.ERROR_SERVICE_MARKED_FOR_DELETE || (err != nil && config.StartType == windows.SERVICE_DISABLED) { 209 log.Printf("[%s] Found disabled service via timeout, so deleting", tunnelName) 210 service.Delete() 211 trackedTunnelsLock.Lock() 212 trackedTunnels[tunnelName] = TunnelStopped 213 trackedTunnelsLock.Unlock() 214 IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, nil) 215 return true 216 } 217 return false 218 } 219 if checkForDisabled() { 220 return 221 } 222 lastState := TunnelUnknown 223 err := trackService(service, func(status uint32) bool { 224 state := notifyStateToTunState(status) 225 var tunnelError error 226 if state == TunnelStopped { 227 serviceStatus, err := service.Query() 228 if err == nil { 229 if serviceStatus.Win32ExitCode == uint32(windows.ERROR_SERVICE_SPECIFIC_ERROR) { 230 maybeErr := services.Error(serviceStatus.ServiceSpecificExitCode) 231 if maybeErr != services.ErrorSuccess { 232 tunnelError = maybeErr 233 } 234 } else { 235 switch serviceStatus.Win32ExitCode { 236 case uint32(windows.NO_ERROR), uint32(windows.ERROR_SERVICE_NEVER_STARTED): 237 default: 238 tunnelError = syscall.Errno(serviceStatus.Win32ExitCode) 239 } 240 } 241 } 242 if tunnelError != nil { 243 service.Delete() 244 } 245 } 246 if state != lastState { 247 trackedTunnelsLock.Lock() 248 trackedTunnels[tunnelName] = state 249 trackedTunnelsLock.Unlock() 250 IPCServerNotifyTunnelChange(tunnelName, state, tunnelError) 251 lastState = state 252 } 253 if state == TunnelUnknown && checkForDisabled() { 254 return true 255 } 256 return state == TunnelStopped 257 }) 258 if err != nil && !checkForDisabled() { 259 trackedTunnelsLock.Lock() 260 trackedTunnels[tunnelName] = TunnelStopped 261 trackedTunnelsLock.Unlock() 262 IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, fmt.Errorf("Unable to continue monitoring service, so stopping: %w", err)) 263 service.Control(svc.Stop) 264 } 265 } 266 267 func trackExistingTunnels() error { 268 m, err := serviceManager() 269 if err != nil { 270 return err 271 } 272 names, err := conf.ListConfigNames() 273 if err != nil { 274 return err 275 } 276 for _, name := range names { 277 trackedTunnelsLock.Lock() 278 if _, found := trackedTunnels[name]; found { 279 trackedTunnelsLock.Unlock() 280 continue 281 } 282 trackedTunnelsLock.Unlock() 283 serviceName, err := conf.ServiceNameOfTunnel(name) 284 if err != nil { 285 continue 286 } 287 service, err := m.OpenService(serviceName) 288 if err != nil { 289 continue 290 } 291 go trackTunnelService(name, service) 292 } 293 return nil 294 } 295 296 var servicesSubscriptionWatcherCallbackPtr = windows.NewCallback(func(notification uint32, context uintptr) uintptr { 297 trackExistingTunnels() 298 return 0 299 }) 300 301 func watchNewTunnelServices() error { 302 m, err := serviceManager() 303 if err != nil { 304 return err 305 } 306 var subscription uintptr 307 err = windows.SubscribeServiceChangeNotifications(m.Handle, windows.SC_EVENT_DATABASE_CHANGE, servicesSubscriptionWatcherCallbackPtr, 0, &subscription) 308 if err == nil { 309 // We probably could do: 310 // defer windows.UnsubscribeServiceChangeNotifications(subscription) 311 // and then terminate after some point, but instead we just let this go forever; it's process-lived. 312 return trackExistingTunnels() 313 } 314 if !errors.Is(err, windows.ERROR_PROC_NOT_FOUND) { 315 return err 316 } 317 318 // TODO: Below this line is Windows 7 compatibility code, which hopefully we can delete at some point. 319 go func() { 320 runtime.LockOSThread() 321 notifier := &windows.SERVICE_NOTIFY{ 322 Version: windows.SERVICE_NOTIFY_STATUS_CHANGE, 323 NotifyCallback: serviceTrackerCallbackPtr, 324 } 325 for { 326 err := windows.NotifyServiceStatusChange(m.Handle, windows.SERVICE_NOTIFY_CREATED, notifier) 327 if err == nil { 328 windows.SleepEx(windows.INFINITE, true) 329 if notifier.ServiceNames != nil { 330 windows.LocalFree(windows.Handle(unsafe.Pointer(notifier.ServiceNames))) 331 notifier.ServiceNames = nil 332 } 333 trackExistingTunnels() 334 } else if err == windows.ERROR_SERVICE_NOTIFY_CLIENT_LAGGING { 335 continue 336 } else { 337 time.Sleep(time.Second * 3) 338 trackExistingTunnels() 339 } 340 } 341 }() 342 return trackExistingTunnels() 343 }