golang.zx2c4.com/wireguard/windows@v0.5.4-0.20230123132234-dcc0eb72a04b/manager/ipc_server.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package manager 7 8 import ( 9 "bytes" 10 "encoding/gob" 11 "fmt" 12 "io" 13 "log" 14 "os" 15 "sync" 16 "sync/atomic" 17 "time" 18 19 "golang.org/x/sys/windows" 20 "golang.org/x/sys/windows/svc" 21 22 "golang.zx2c4.com/wireguard/windows/conf" 23 "golang.zx2c4.com/wireguard/windows/updater" 24 ) 25 26 var ( 27 managerServices = make(map[*ManagerService]bool) 28 managerServicesLock sync.RWMutex 29 haveQuit uint32 30 quitManagersChan = make(chan struct{}, 1) 31 ) 32 33 type ManagerService struct { 34 events *os.File 35 eventLock sync.Mutex 36 elevatedToken windows.Token 37 } 38 39 func (s *ManagerService) StoredConfig(tunnelName string) (*conf.Config, error) { 40 conf, err := conf.LoadFromName(tunnelName) 41 if err != nil { 42 return nil, err 43 } 44 if s.elevatedToken == 0 { 45 conf.Redact() 46 } 47 return conf, nil 48 } 49 50 func (s *ManagerService) RuntimeConfig(tunnelName string) (*conf.Config, error) { 51 storedConfig, err := conf.LoadFromName(tunnelName) 52 if err != nil { 53 return nil, err 54 } 55 driverAdapter, err := findDriverAdapter(tunnelName) 56 if err != nil { 57 return nil, err 58 } 59 runtimeConfig, err := driverAdapter.Configuration() 60 if err != nil { 61 driverAdapter.Unlock() 62 releaseDriverAdapter(tunnelName) 63 return nil, err 64 } 65 conf := conf.FromDriverConfiguration(runtimeConfig, storedConfig) 66 driverAdapter.Unlock() 67 if s.elevatedToken == 0 { 68 conf.Redact() 69 } 70 return conf, nil 71 } 72 73 func (s *ManagerService) Start(tunnelName string) error { 74 c, err := conf.LoadFromName(tunnelName) 75 if err != nil { 76 return err 77 } 78 79 // Figure out which tunnels have intersecting addresses/routes and stop those. 80 trackedTunnelsLock.Lock() 81 tt := make([]string, 0, len(trackedTunnels)) 82 var inTransition string 83 for t, state := range trackedTunnels { 84 c2, err := conf.LoadFromName(t) 85 if err != nil || !c.IntersectsWith(c2) { 86 // If we can't get the config, assume it doesn't intersect. 87 continue 88 } 89 tt = append(tt, t) 90 if len(t) > 0 && (state == TunnelStarting || state == TunnelUnknown) { 91 inTransition = t 92 break 93 } 94 } 95 trackedTunnelsLock.Unlock() 96 if len(inTransition) != 0 { 97 return fmt.Errorf("Please allow the tunnel ā%sā to finish activating", inTransition) 98 } 99 100 // Stop those intersecting tunnels asynchronously. 101 go func() { 102 for _, t := range tt { 103 s.Stop(t) 104 } 105 for _, t := range tt { 106 state, err := s.State(t) 107 if err == nil && (state == TunnelStarted || state == TunnelStarting) { 108 log.Printf("[%s] Trying again to stop zombie tunnel", t) 109 s.Stop(t) 110 time.Sleep(time.Millisecond * 100) 111 } 112 } 113 }() 114 // After the stop process has begun, but before it's finished, we install the new one. 115 path, err := c.Path() 116 if err != nil { 117 return err 118 } 119 return InstallTunnel(path) 120 } 121 122 func (s *ManagerService) Stop(tunnelName string) error { 123 err := UninstallTunnel(tunnelName) 124 if err == windows.ERROR_SERVICE_DOES_NOT_EXIST { 125 _, notExistsError := conf.LoadFromName(tunnelName) 126 if notExistsError == nil { 127 return nil 128 } 129 } 130 return err 131 } 132 133 func (s *ManagerService) WaitForStop(tunnelName string) error { 134 serviceName, err := conf.ServiceNameOfTunnel(tunnelName) 135 if err != nil { 136 return err 137 } 138 m, err := serviceManager() 139 if err != nil { 140 return err 141 } 142 for { 143 service, err := m.OpenService(serviceName) 144 if err == nil || err == windows.ERROR_SERVICE_MARKED_FOR_DELETE { 145 service.Close() 146 time.Sleep(time.Second / 3) 147 } else { 148 return nil 149 } 150 } 151 } 152 153 func (s *ManagerService) Delete(tunnelName string) error { 154 if s.elevatedToken == 0 { 155 return windows.ERROR_ACCESS_DENIED 156 } 157 err := s.Stop(tunnelName) 158 if err != nil { 159 return err 160 } 161 return conf.DeleteName(tunnelName) 162 } 163 164 func (s *ManagerService) State(tunnelName string) (TunnelState, error) { 165 serviceName, err := conf.ServiceNameOfTunnel(tunnelName) 166 if err != nil { 167 return 0, err 168 } 169 m, err := serviceManager() 170 if err != nil { 171 return 0, err 172 } 173 service, err := m.OpenService(serviceName) 174 if err != nil { 175 return TunnelStopped, nil 176 } 177 defer service.Close() 178 status, err := service.Query() 179 if err != nil { 180 return TunnelUnknown, nil 181 } 182 switch status.State { 183 case svc.Stopped: 184 return TunnelStopped, nil 185 case svc.StopPending: 186 return TunnelStopping, nil 187 case svc.Running: 188 return TunnelStarted, nil 189 case svc.StartPending: 190 return TunnelStarting, nil 191 default: 192 return TunnelUnknown, nil 193 } 194 } 195 196 func (s *ManagerService) GlobalState() TunnelState { 197 return trackedTunnelsGlobalState() 198 } 199 200 func (s *ManagerService) Create(tunnelConfig *conf.Config) (*Tunnel, error) { 201 if s.elevatedToken == 0 { 202 return nil, windows.ERROR_ACCESS_DENIED 203 } 204 err := tunnelConfig.Save(true) 205 if err != nil { 206 return nil, err 207 } 208 return &Tunnel{tunnelConfig.Name}, nil 209 // TODO: handle already existing situation 210 // TODO: handle already running and existing situation 211 } 212 213 func (s *ManagerService) Tunnels() ([]Tunnel, error) { 214 names, err := conf.ListConfigNames() 215 if err != nil { 216 return nil, err 217 } 218 tunnels := make([]Tunnel, len(names)) 219 for i := 0; i < len(tunnels); i++ { 220 tunnels[i].Name = names[i] 221 } 222 return tunnels, nil 223 // TODO: account for running ones that aren't in the configuration store somehow 224 } 225 226 func (s *ManagerService) Quit(stopTunnelsOnQuit bool) (alreadyQuit bool, err error) { 227 if s.elevatedToken == 0 { 228 return false, windows.ERROR_ACCESS_DENIED 229 } 230 if !atomic.CompareAndSwapUint32(&haveQuit, 0, 1) { 231 return true, nil 232 } 233 234 // Work around potential race condition of delivering messages to the wrong process by removing from notifications. 235 managerServicesLock.Lock() 236 s.eventLock.Lock() 237 s.events = nil 238 s.eventLock.Unlock() 239 delete(managerServices, s) 240 managerServicesLock.Unlock() 241 242 if stopTunnelsOnQuit { 243 names, err := conf.ListConfigNames() 244 if err != nil { 245 return false, err 246 } 247 for _, name := range names { 248 UninstallTunnel(name) 249 } 250 } 251 252 quitManagersChan <- struct{}{} 253 return false, nil 254 } 255 256 func (s *ManagerService) UpdateState() UpdateState { 257 return updateState 258 } 259 260 func (s *ManagerService) Update() { 261 if s.elevatedToken == 0 { 262 return 263 } 264 progress := updater.DownloadVerifyAndExecute(uintptr(s.elevatedToken)) 265 go func() { 266 for { 267 dp := <-progress 268 IPCServerNotifyUpdateProgress(dp) 269 if dp.Complete || dp.Error != nil { 270 return 271 } 272 } 273 }() 274 } 275 276 func (s *ManagerService) ServeConn(reader io.Reader, writer io.Writer) { 277 decoder := gob.NewDecoder(reader) 278 encoder := gob.NewEncoder(writer) 279 for { 280 var methodType MethodType 281 err := decoder.Decode(&methodType) 282 if err != nil { 283 return 284 } 285 switch methodType { 286 case StoredConfigMethodType: 287 var tunnelName string 288 err := decoder.Decode(&tunnelName) 289 if err != nil { 290 return 291 } 292 config, retErr := s.StoredConfig(tunnelName) 293 if config == nil { 294 config = &conf.Config{} 295 } 296 err = encoder.Encode(*config) 297 if err != nil { 298 return 299 } 300 err = encoder.Encode(errToString(retErr)) 301 if err != nil { 302 return 303 } 304 case RuntimeConfigMethodType: 305 var tunnelName string 306 err := decoder.Decode(&tunnelName) 307 if err != nil { 308 return 309 } 310 config, retErr := s.RuntimeConfig(tunnelName) 311 if config == nil { 312 config = &conf.Config{} 313 } 314 err = encoder.Encode(*config) 315 if err != nil { 316 return 317 } 318 err = encoder.Encode(errToString(retErr)) 319 if err != nil { 320 return 321 } 322 case StartMethodType: 323 var tunnelName string 324 err := decoder.Decode(&tunnelName) 325 if err != nil { 326 return 327 } 328 retErr := s.Start(tunnelName) 329 err = encoder.Encode(errToString(retErr)) 330 if err != nil { 331 return 332 } 333 case StopMethodType: 334 var tunnelName string 335 err := decoder.Decode(&tunnelName) 336 if err != nil { 337 return 338 } 339 retErr := s.Stop(tunnelName) 340 err = encoder.Encode(errToString(retErr)) 341 if err != nil { 342 return 343 } 344 case WaitForStopMethodType: 345 var tunnelName string 346 err := decoder.Decode(&tunnelName) 347 if err != nil { 348 return 349 } 350 retErr := s.WaitForStop(tunnelName) 351 err = encoder.Encode(errToString(retErr)) 352 if err != nil { 353 return 354 } 355 case DeleteMethodType: 356 var tunnelName string 357 err := decoder.Decode(&tunnelName) 358 if err != nil { 359 return 360 } 361 retErr := s.Delete(tunnelName) 362 err = encoder.Encode(errToString(retErr)) 363 if err != nil { 364 return 365 } 366 case StateMethodType: 367 var tunnelName string 368 err := decoder.Decode(&tunnelName) 369 if err != nil { 370 return 371 } 372 state, retErr := s.State(tunnelName) 373 err = encoder.Encode(state) 374 if err != nil { 375 return 376 } 377 err = encoder.Encode(errToString(retErr)) 378 if err != nil { 379 return 380 } 381 case GlobalStateMethodType: 382 state := s.GlobalState() 383 err = encoder.Encode(state) 384 if err != nil { 385 return 386 } 387 case CreateMethodType: 388 var config conf.Config 389 err := decoder.Decode(&config) 390 if err != nil { 391 return 392 } 393 tunnel, retErr := s.Create(&config) 394 if tunnel == nil { 395 tunnel = &Tunnel{} 396 } 397 err = encoder.Encode(tunnel) 398 if err != nil { 399 return 400 } 401 err = encoder.Encode(errToString(retErr)) 402 if err != nil { 403 return 404 } 405 case TunnelsMethodType: 406 tunnels, retErr := s.Tunnels() 407 err = encoder.Encode(tunnels) 408 if err != nil { 409 return 410 } 411 err = encoder.Encode(errToString(retErr)) 412 if err != nil { 413 return 414 } 415 case QuitMethodType: 416 var stopTunnelsOnQuit bool 417 err := decoder.Decode(&stopTunnelsOnQuit) 418 if err != nil { 419 return 420 } 421 alreadyQuit, retErr := s.Quit(stopTunnelsOnQuit) 422 err = encoder.Encode(alreadyQuit) 423 if err != nil { 424 return 425 } 426 err = encoder.Encode(errToString(retErr)) 427 if err != nil { 428 return 429 } 430 case UpdateStateMethodType: 431 updateState := s.UpdateState() 432 err = encoder.Encode(updateState) 433 if err != nil { 434 return 435 } 436 case UpdateMethodType: 437 s.Update() 438 default: 439 return 440 } 441 } 442 } 443 444 func IPCServerListen(reader, writer, events *os.File, elevatedToken windows.Token) { 445 service := &ManagerService{ 446 events: events, 447 elevatedToken: elevatedToken, 448 } 449 450 go func() { 451 managerServicesLock.Lock() 452 managerServices[service] = true 453 managerServicesLock.Unlock() 454 service.ServeConn(reader, writer) 455 managerServicesLock.Lock() 456 service.eventLock.Lock() 457 service.events = nil 458 service.eventLock.Unlock() 459 delete(managerServices, service) 460 managerServicesLock.Unlock() 461 }() 462 } 463 464 func notifyAll(notificationType NotificationType, adminOnly bool, ifaces ...any) { 465 if len(managerServices) == 0 { 466 return 467 } 468 469 var buf bytes.Buffer 470 encoder := gob.NewEncoder(&buf) 471 err := encoder.Encode(notificationType) 472 if err != nil { 473 return 474 } 475 for _, iface := range ifaces { 476 err = encoder.Encode(iface) 477 if err != nil { 478 return 479 } 480 } 481 482 managerServicesLock.RLock() 483 for m := range managerServices { 484 if m.elevatedToken == 0 && adminOnly { 485 continue 486 } 487 go func(m *ManagerService) { 488 m.eventLock.Lock() 489 defer m.eventLock.Unlock() 490 if m.events != nil { 491 m.events.SetWriteDeadline(time.Now().Add(time.Second)) 492 m.events.Write(buf.Bytes()) 493 } 494 }(m) 495 } 496 managerServicesLock.RUnlock() 497 } 498 499 func errToString(err error) string { 500 if err == nil { 501 return "" 502 } 503 return err.Error() 504 } 505 506 func IPCServerNotifyTunnelChange(name string, state TunnelState, err error) { 507 notifyAll(TunnelChangeNotificationType, false, name, state, trackedTunnelsGlobalState(), errToString(err)) 508 } 509 510 func IPCServerNotifyTunnelsChange() { 511 notifyAll(TunnelsChangeNotificationType, false) 512 } 513 514 func IPCServerNotifyUpdateFound(state UpdateState) { 515 notifyAll(UpdateFoundNotificationType, false, state) 516 } 517 518 func IPCServerNotifyUpdateProgress(dp updater.DownloadProgress) { 519 notifyAll(UpdateProgressNotificationType, true, dp.Activity, dp.BytesDownloaded, dp.BytesTotal, errToString(dp.Error), dp.Complete) 520 } 521 522 func IPCServerNotifyManagerStopping() { 523 notifyAll(ManagerStoppingNotificationType, false) 524 time.Sleep(time.Millisecond * 200) 525 }