golang.zx2c4.com/wireguard/windows@v0.5.4-0.20230123132234-dcc0eb72a04b/manager/ipc_server.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package manager
     7  
     8  import (
     9  	"bytes"
    10  	"encoding/gob"
    11  	"fmt"
    12  	"io"
    13  	"log"
    14  	"os"
    15  	"sync"
    16  	"sync/atomic"
    17  	"time"
    18  
    19  	"golang.org/x/sys/windows"
    20  	"golang.org/x/sys/windows/svc"
    21  
    22  	"golang.zx2c4.com/wireguard/windows/conf"
    23  	"golang.zx2c4.com/wireguard/windows/updater"
    24  )
    25  
    26  var (
    27  	managerServices     = make(map[*ManagerService]bool)
    28  	managerServicesLock sync.RWMutex
    29  	haveQuit            uint32
    30  	quitManagersChan    = make(chan struct{}, 1)
    31  )
    32  
// ManagerService serves a single IPC client connection. Its methods form the
// manager's request surface; privileged operations check elevatedToken.
type ManagerService struct {
	// events is the client's notification pipe. notifyAll writes to it, and it
	// is set to nil (under eventLock) when the client is unregistered.
	events        *os.File
	// eventLock serializes writes to events and guards clearing it.
	eventLock     sync.Mutex
	// elevatedToken is 0 for an unprivileged client. When it is 0, privileged
	// methods return ERROR_ACCESS_DENIED (or do nothing) and configs are redacted.
	elevatedToken windows.Token
}
    38  
    39  func (s *ManagerService) StoredConfig(tunnelName string) (*conf.Config, error) {
    40  	conf, err := conf.LoadFromName(tunnelName)
    41  	if err != nil {
    42  		return nil, err
    43  	}
    44  	if s.elevatedToken == 0 {
    45  		conf.Redact()
    46  	}
    47  	return conf, nil
    48  }
    49  
    50  func (s *ManagerService) RuntimeConfig(tunnelName string) (*conf.Config, error) {
    51  	storedConfig, err := conf.LoadFromName(tunnelName)
    52  	if err != nil {
    53  		return nil, err
    54  	}
    55  	driverAdapter, err := findDriverAdapter(tunnelName)
    56  	if err != nil {
    57  		return nil, err
    58  	}
    59  	runtimeConfig, err := driverAdapter.Configuration()
    60  	if err != nil {
    61  		driverAdapter.Unlock()
    62  		releaseDriverAdapter(tunnelName)
    63  		return nil, err
    64  	}
    65  	conf := conf.FromDriverConfiguration(runtimeConfig, storedConfig)
    66  	driverAdapter.Unlock()
    67  	if s.elevatedToken == 0 {
    68  		conf.Redact()
    69  	}
    70  	return conf, nil
    71  }
    72  
    73  func (s *ManagerService) Start(tunnelName string) error {
    74  	c, err := conf.LoadFromName(tunnelName)
    75  	if err != nil {
    76  		return err
    77  	}
    78  
    79  	// Figure out which tunnels have intersecting addresses/routes and stop those.
    80  	trackedTunnelsLock.Lock()
    81  	tt := make([]string, 0, len(trackedTunnels))
    82  	var inTransition string
    83  	for t, state := range trackedTunnels {
    84  		c2, err := conf.LoadFromName(t)
    85  		if err != nil || !c.IntersectsWith(c2) {
    86  			// If we can't get the config, assume it doesn't intersect.
    87  			continue
    88  		}
    89  		tt = append(tt, t)
    90  		if len(t) > 0 && (state == TunnelStarting || state == TunnelUnknown) {
    91  			inTransition = t
    92  			break
    93  		}
    94  	}
    95  	trackedTunnelsLock.Unlock()
    96  	if len(inTransition) != 0 {
    97  		return fmt.Errorf("Please allow the tunnel ā€˜%sā€™ to finish activating", inTransition)
    98  	}
    99  
   100  	// Stop those intersecting tunnels asynchronously.
   101  	go func() {
   102  		for _, t := range tt {
   103  			s.Stop(t)
   104  		}
   105  		for _, t := range tt {
   106  			state, err := s.State(t)
   107  			if err == nil && (state == TunnelStarted || state == TunnelStarting) {
   108  				log.Printf("[%s] Trying again to stop zombie tunnel", t)
   109  				s.Stop(t)
   110  				time.Sleep(time.Millisecond * 100)
   111  			}
   112  		}
   113  	}()
   114  	// After the stop process has begun, but before it's finished, we install the new one.
   115  	path, err := c.Path()
   116  	if err != nil {
   117  		return err
   118  	}
   119  	return InstallTunnel(path)
   120  }
   121  
   122  func (s *ManagerService) Stop(tunnelName string) error {
   123  	err := UninstallTunnel(tunnelName)
   124  	if err == windows.ERROR_SERVICE_DOES_NOT_EXIST {
   125  		_, notExistsError := conf.LoadFromName(tunnelName)
   126  		if notExistsError == nil {
   127  			return nil
   128  		}
   129  	}
   130  	return err
   131  }
   132  
   133  func (s *ManagerService) WaitForStop(tunnelName string) error {
   134  	serviceName, err := conf.ServiceNameOfTunnel(tunnelName)
   135  	if err != nil {
   136  		return err
   137  	}
   138  	m, err := serviceManager()
   139  	if err != nil {
   140  		return err
   141  	}
   142  	for {
   143  		service, err := m.OpenService(serviceName)
   144  		if err == nil || err == windows.ERROR_SERVICE_MARKED_FOR_DELETE {
   145  			service.Close()
   146  			time.Sleep(time.Second / 3)
   147  		} else {
   148  			return nil
   149  		}
   150  	}
   151  }
   152  
   153  func (s *ManagerService) Delete(tunnelName string) error {
   154  	if s.elevatedToken == 0 {
   155  		return windows.ERROR_ACCESS_DENIED
   156  	}
   157  	err := s.Stop(tunnelName)
   158  	if err != nil {
   159  		return err
   160  	}
   161  	return conf.DeleteName(tunnelName)
   162  }
   163  
   164  func (s *ManagerService) State(tunnelName string) (TunnelState, error) {
   165  	serviceName, err := conf.ServiceNameOfTunnel(tunnelName)
   166  	if err != nil {
   167  		return 0, err
   168  	}
   169  	m, err := serviceManager()
   170  	if err != nil {
   171  		return 0, err
   172  	}
   173  	service, err := m.OpenService(serviceName)
   174  	if err != nil {
   175  		return TunnelStopped, nil
   176  	}
   177  	defer service.Close()
   178  	status, err := service.Query()
   179  	if err != nil {
   180  		return TunnelUnknown, nil
   181  	}
   182  	switch status.State {
   183  	case svc.Stopped:
   184  		return TunnelStopped, nil
   185  	case svc.StopPending:
   186  		return TunnelStopping, nil
   187  	case svc.Running:
   188  		return TunnelStarted, nil
   189  	case svc.StartPending:
   190  		return TunnelStarting, nil
   191  	default:
   192  		return TunnelUnknown, nil
   193  	}
   194  }
   195  
   196  func (s *ManagerService) GlobalState() TunnelState {
   197  	return trackedTunnelsGlobalState()
   198  }
   199  
   200  func (s *ManagerService) Create(tunnelConfig *conf.Config) (*Tunnel, error) {
   201  	if s.elevatedToken == 0 {
   202  		return nil, windows.ERROR_ACCESS_DENIED
   203  	}
   204  	err := tunnelConfig.Save(true)
   205  	if err != nil {
   206  		return nil, err
   207  	}
   208  	return &Tunnel{tunnelConfig.Name}, nil
   209  	// TODO: handle already existing situation
   210  	// TODO: handle already running and existing situation
   211  }
   212  
   213  func (s *ManagerService) Tunnels() ([]Tunnel, error) {
   214  	names, err := conf.ListConfigNames()
   215  	if err != nil {
   216  		return nil, err
   217  	}
   218  	tunnels := make([]Tunnel, len(names))
   219  	for i := 0; i < len(tunnels); i++ {
   220  		tunnels[i].Name = names[i]
   221  	}
   222  	return tunnels, nil
   223  	// TODO: account for running ones that aren't in the configuration store somehow
   224  }
   225  
   226  func (s *ManagerService) Quit(stopTunnelsOnQuit bool) (alreadyQuit bool, err error) {
   227  	if s.elevatedToken == 0 {
   228  		return false, windows.ERROR_ACCESS_DENIED
   229  	}
   230  	if !atomic.CompareAndSwapUint32(&haveQuit, 0, 1) {
   231  		return true, nil
   232  	}
   233  
   234  	// Work around potential race condition of delivering messages to the wrong process by removing from notifications.
   235  	managerServicesLock.Lock()
   236  	s.eventLock.Lock()
   237  	s.events = nil
   238  	s.eventLock.Unlock()
   239  	delete(managerServices, s)
   240  	managerServicesLock.Unlock()
   241  
   242  	if stopTunnelsOnQuit {
   243  		names, err := conf.ListConfigNames()
   244  		if err != nil {
   245  			return false, err
   246  		}
   247  		for _, name := range names {
   248  			UninstallTunnel(name)
   249  		}
   250  	}
   251  
   252  	quitManagersChan <- struct{}{}
   253  	return false, nil
   254  }
   255  
   256  func (s *ManagerService) UpdateState() UpdateState {
   257  	return updateState
   258  }
   259  
   260  func (s *ManagerService) Update() {
   261  	if s.elevatedToken == 0 {
   262  		return
   263  	}
   264  	progress := updater.DownloadVerifyAndExecute(uintptr(s.elevatedToken))
   265  	go func() {
   266  		for {
   267  			dp := <-progress
   268  			IPCServerNotifyUpdateProgress(dp)
   269  			if dp.Complete || dp.Error != nil {
   270  				return
   271  			}
   272  		}
   273  	}()
   274  }
   275  
   276  func (s *ManagerService) ServeConn(reader io.Reader, writer io.Writer) {
   277  	decoder := gob.NewDecoder(reader)
   278  	encoder := gob.NewEncoder(writer)
   279  	for {
   280  		var methodType MethodType
   281  		err := decoder.Decode(&methodType)
   282  		if err != nil {
   283  			return
   284  		}
   285  		switch methodType {
   286  		case StoredConfigMethodType:
   287  			var tunnelName string
   288  			err := decoder.Decode(&tunnelName)
   289  			if err != nil {
   290  				return
   291  			}
   292  			config, retErr := s.StoredConfig(tunnelName)
   293  			if config == nil {
   294  				config = &conf.Config{}
   295  			}
   296  			err = encoder.Encode(*config)
   297  			if err != nil {
   298  				return
   299  			}
   300  			err = encoder.Encode(errToString(retErr))
   301  			if err != nil {
   302  				return
   303  			}
   304  		case RuntimeConfigMethodType:
   305  			var tunnelName string
   306  			err := decoder.Decode(&tunnelName)
   307  			if err != nil {
   308  				return
   309  			}
   310  			config, retErr := s.RuntimeConfig(tunnelName)
   311  			if config == nil {
   312  				config = &conf.Config{}
   313  			}
   314  			err = encoder.Encode(*config)
   315  			if err != nil {
   316  				return
   317  			}
   318  			err = encoder.Encode(errToString(retErr))
   319  			if err != nil {
   320  				return
   321  			}
   322  		case StartMethodType:
   323  			var tunnelName string
   324  			err := decoder.Decode(&tunnelName)
   325  			if err != nil {
   326  				return
   327  			}
   328  			retErr := s.Start(tunnelName)
   329  			err = encoder.Encode(errToString(retErr))
   330  			if err != nil {
   331  				return
   332  			}
   333  		case StopMethodType:
   334  			var tunnelName string
   335  			err := decoder.Decode(&tunnelName)
   336  			if err != nil {
   337  				return
   338  			}
   339  			retErr := s.Stop(tunnelName)
   340  			err = encoder.Encode(errToString(retErr))
   341  			if err != nil {
   342  				return
   343  			}
   344  		case WaitForStopMethodType:
   345  			var tunnelName string
   346  			err := decoder.Decode(&tunnelName)
   347  			if err != nil {
   348  				return
   349  			}
   350  			retErr := s.WaitForStop(tunnelName)
   351  			err = encoder.Encode(errToString(retErr))
   352  			if err != nil {
   353  				return
   354  			}
   355  		case DeleteMethodType:
   356  			var tunnelName string
   357  			err := decoder.Decode(&tunnelName)
   358  			if err != nil {
   359  				return
   360  			}
   361  			retErr := s.Delete(tunnelName)
   362  			err = encoder.Encode(errToString(retErr))
   363  			if err != nil {
   364  				return
   365  			}
   366  		case StateMethodType:
   367  			var tunnelName string
   368  			err := decoder.Decode(&tunnelName)
   369  			if err != nil {
   370  				return
   371  			}
   372  			state, retErr := s.State(tunnelName)
   373  			err = encoder.Encode(state)
   374  			if err != nil {
   375  				return
   376  			}
   377  			err = encoder.Encode(errToString(retErr))
   378  			if err != nil {
   379  				return
   380  			}
   381  		case GlobalStateMethodType:
   382  			state := s.GlobalState()
   383  			err = encoder.Encode(state)
   384  			if err != nil {
   385  				return
   386  			}
   387  		case CreateMethodType:
   388  			var config conf.Config
   389  			err := decoder.Decode(&config)
   390  			if err != nil {
   391  				return
   392  			}
   393  			tunnel, retErr := s.Create(&config)
   394  			if tunnel == nil {
   395  				tunnel = &Tunnel{}
   396  			}
   397  			err = encoder.Encode(tunnel)
   398  			if err != nil {
   399  				return
   400  			}
   401  			err = encoder.Encode(errToString(retErr))
   402  			if err != nil {
   403  				return
   404  			}
   405  		case TunnelsMethodType:
   406  			tunnels, retErr := s.Tunnels()
   407  			err = encoder.Encode(tunnels)
   408  			if err != nil {
   409  				return
   410  			}
   411  			err = encoder.Encode(errToString(retErr))
   412  			if err != nil {
   413  				return
   414  			}
   415  		case QuitMethodType:
   416  			var stopTunnelsOnQuit bool
   417  			err := decoder.Decode(&stopTunnelsOnQuit)
   418  			if err != nil {
   419  				return
   420  			}
   421  			alreadyQuit, retErr := s.Quit(stopTunnelsOnQuit)
   422  			err = encoder.Encode(alreadyQuit)
   423  			if err != nil {
   424  				return
   425  			}
   426  			err = encoder.Encode(errToString(retErr))
   427  			if err != nil {
   428  				return
   429  			}
   430  		case UpdateStateMethodType:
   431  			updateState := s.UpdateState()
   432  			err = encoder.Encode(updateState)
   433  			if err != nil {
   434  				return
   435  			}
   436  		case UpdateMethodType:
   437  			s.Update()
   438  		default:
   439  			return
   440  		}
   441  	}
   442  }
   443  
   444  func IPCServerListen(reader, writer, events *os.File, elevatedToken windows.Token) {
   445  	service := &ManagerService{
   446  		events:        events,
   447  		elevatedToken: elevatedToken,
   448  	}
   449  
   450  	go func() {
   451  		managerServicesLock.Lock()
   452  		managerServices[service] = true
   453  		managerServicesLock.Unlock()
   454  		service.ServeConn(reader, writer)
   455  		managerServicesLock.Lock()
   456  		service.eventLock.Lock()
   457  		service.events = nil
   458  		service.eventLock.Unlock()
   459  		delete(managerServices, service)
   460  		managerServicesLock.Unlock()
   461  	}()
   462  }
   463  
   464  func notifyAll(notificationType NotificationType, adminOnly bool, ifaces ...any) {
   465  	if len(managerServices) == 0 {
   466  		return
   467  	}
   468  
   469  	var buf bytes.Buffer
   470  	encoder := gob.NewEncoder(&buf)
   471  	err := encoder.Encode(notificationType)
   472  	if err != nil {
   473  		return
   474  	}
   475  	for _, iface := range ifaces {
   476  		err = encoder.Encode(iface)
   477  		if err != nil {
   478  			return
   479  		}
   480  	}
   481  
   482  	managerServicesLock.RLock()
   483  	for m := range managerServices {
   484  		if m.elevatedToken == 0 && adminOnly {
   485  			continue
   486  		}
   487  		go func(m *ManagerService) {
   488  			m.eventLock.Lock()
   489  			defer m.eventLock.Unlock()
   490  			if m.events != nil {
   491  				m.events.SetWriteDeadline(time.Now().Add(time.Second))
   492  				m.events.Write(buf.Bytes())
   493  			}
   494  		}(m)
   495  	}
   496  	managerServicesLock.RUnlock()
   497  }
   498  
   499  func errToString(err error) string {
   500  	if err == nil {
   501  		return ""
   502  	}
   503  	return err.Error()
   504  }
   505  
   506  func IPCServerNotifyTunnelChange(name string, state TunnelState, err error) {
   507  	notifyAll(TunnelChangeNotificationType, false, name, state, trackedTunnelsGlobalState(), errToString(err))
   508  }
   509  
   510  func IPCServerNotifyTunnelsChange() {
   511  	notifyAll(TunnelsChangeNotificationType, false)
   512  }
   513  
   514  func IPCServerNotifyUpdateFound(state UpdateState) {
   515  	notifyAll(UpdateFoundNotificationType, false, state)
   516  }
   517  
   518  func IPCServerNotifyUpdateProgress(dp updater.DownloadProgress) {
   519  	notifyAll(UpdateProgressNotificationType, true, dp.Activity, dp.BytesDownloaded, dp.BytesTotal, errToString(dp.Error), dp.Complete)
   520  }
   521  
   522  func IPCServerNotifyManagerStopping() {
   523  	notifyAll(ManagerStoppingNotificationType, false)
   524  	time.Sleep(time.Millisecond * 200)
   525  }