golang.zx2c4.com/wireguard/windows@v0.5.4-0.20230123132234-dcc0eb72a04b/manager/service.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package manager
     7  
     8  import (
     9  	"errors"
    10  	"log"
    11  	"os"
    12  	"runtime"
    13  	"strconv"
    14  	"sync"
    15  	"time"
    16  	"unsafe"
    17  
    18  	"golang.org/x/sys/windows"
    19  	"golang.org/x/sys/windows/svc"
    20  	"golang.zx2c4.com/wireguard/windows/driver"
    21  
    22  	"golang.zx2c4.com/wireguard/windows/conf"
    23  	"golang.zx2c4.com/wireguard/windows/elevate"
    24  	"golang.zx2c4.com/wireguard/windows/ringlogger"
    25  	"golang.zx2c4.com/wireguard/windows/services"
    26  )
    27  
    28  type managerService struct{}
    29  
    30  func (service *managerService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (svcSpecificEC bool, exitCode uint32) {
    31  	changes <- svc.Status{State: svc.StartPending}
    32  
    33  	var err error
    34  	serviceError := services.ErrorSuccess
    35  
    36  	defer func() {
    37  		svcSpecificEC, exitCode = services.DetermineErrorCode(err, serviceError)
    38  		logErr := services.CombineErrors(err, serviceError)
    39  		if logErr != nil {
    40  			log.Print(logErr)
    41  		}
    42  		changes <- svc.Status{State: svc.StopPending}
    43  	}()
    44  
    45  	var logFile string
    46  	logFile, err = conf.LogFile(true)
    47  	if err != nil {
    48  		serviceError = services.ErrorRingloggerOpen
    49  		return
    50  	}
    51  	err = ringlogger.InitGlobalLogger(logFile, "MGR")
    52  	if err != nil {
    53  		serviceError = services.ErrorRingloggerOpen
    54  		return
    55  	}
    56  
    57  	services.PrintStarting()
    58  
    59  	path, err := os.Executable()
    60  	if err != nil {
    61  		serviceError = services.ErrorDetermineExecutablePath
    62  		return
    63  	}
    64  
    65  	err = watchNewTunnelServices()
    66  	if err != nil {
    67  		serviceError = services.ErrorTrackTunnels
    68  		return
    69  	}
    70  
    71  	conf.RegisterStoreChangeCallback(func() { conf.MigrateUnencryptedConfigs(changeTunnelServiceConfigFilePath) })
    72  	conf.RegisterStoreChangeCallback(IPCServerNotifyTunnelsChange)
    73  
    74  	procs := make(map[uint32]*uiProcess)
    75  	aliveSessions := make(map[uint32]bool)
    76  	procsLock := sync.Mutex{}
    77  	stoppingManager := false
    78  	operatorGroupSid, _ := windows.CreateWellKnownSid(windows.WinBuiltinNetworkConfigurationOperatorsSid)
    79  
    80  	startProcess := func(session uint32) {
    81  		defer func() {
    82  			runtime.UnlockOSThread()
    83  			procsLock.Lock()
    84  			delete(aliveSessions, session)
    85  			procsLock.Unlock()
    86  		}()
    87  
    88  		var userToken windows.Token
    89  		err := windows.WTSQueryUserToken(session, &userToken)
    90  		if err != nil {
    91  			return
    92  		}
    93  		isAdmin := elevate.TokenIsElevatedOrElevatable(userToken)
    94  		isOperator := false
    95  		if !isAdmin && conf.AdminBool("LimitedOperatorUI") && operatorGroupSid != nil {
    96  			linkedToken, err := userToken.GetLinkedToken()
    97  			var impersonationToken windows.Token
    98  			if err == nil {
    99  				err = windows.DuplicateTokenEx(linkedToken, windows.TOKEN_QUERY, nil, windows.SecurityImpersonation, windows.TokenImpersonation, &impersonationToken)
   100  				linkedToken.Close()
   101  			} else {
   102  				err = windows.DuplicateTokenEx(userToken, windows.TOKEN_QUERY, nil, windows.SecurityImpersonation, windows.TokenImpersonation, &impersonationToken)
   103  			}
   104  			if err == nil {
   105  				isOperator, err = impersonationToken.IsMember(operatorGroupSid)
   106  				isOperator = isOperator && err == nil
   107  				impersonationToken.Close()
   108  			}
   109  		}
   110  		if !isAdmin && !isOperator {
   111  			userToken.Close()
   112  			return
   113  		}
   114  		user, err := userToken.GetTokenUser()
   115  		if err != nil {
   116  			log.Printf("Unable to lookup user from token: %v", err)
   117  			userToken.Close()
   118  			return
   119  		}
   120  		username, domain, accType, err := user.User.Sid.LookupAccount("")
   121  		if err != nil {
   122  			log.Printf("Unable to lookup username from sid: %v", err)
   123  			userToken.Close()
   124  			return
   125  		}
   126  		if accType != windows.SidTypeUser {
   127  			userToken.Close()
   128  			return
   129  		}
   130  		userProfileDirectory, _ := userToken.GetUserProfileDirectory()
   131  		var elevatedToken, runToken windows.Token
   132  		if isAdmin {
   133  			if userToken.IsElevated() {
   134  				elevatedToken = userToken
   135  			} else {
   136  				elevatedToken, err = userToken.GetLinkedToken()
   137  				userToken.Close()
   138  				if err != nil {
   139  					log.Printf("Unable to elevate token: %v", err)
   140  					return
   141  				}
   142  				if !elevatedToken.IsElevated() {
   143  					elevatedToken.Close()
   144  					log.Println("Linked token is not elevated")
   145  					return
   146  				}
   147  			}
   148  			runToken = elevatedToken
   149  		} else {
   150  			runToken = userToken
   151  		}
   152  		defer runToken.Close()
   153  		userToken = 0
   154  		first := true
   155  		for {
   156  			if stoppingManager {
   157  				return
   158  			}
   159  
   160  			procsLock.Lock()
   161  			if alive := aliveSessions[session]; !alive {
   162  				procsLock.Unlock()
   163  				return
   164  			}
   165  			procsLock.Unlock()
   166  
   167  			if !first {
   168  				time.Sleep(time.Second)
   169  			} else {
   170  				first = false
   171  			}
   172  
   173  			ourReader, theirWriter, err := os.Pipe()
   174  			if err != nil {
   175  				log.Printf("Unable to create pipe: %v", err)
   176  				return
   177  			}
   178  			theirReader, ourWriter, err := os.Pipe()
   179  			if err != nil {
   180  				log.Printf("Unable to create pipe: %v", err)
   181  				return
   182  			}
   183  			theirEvents, ourEvents, err := os.Pipe()
   184  			if err != nil {
   185  				log.Printf("Unable to create pipe: %v", err)
   186  				return
   187  			}
   188  			IPCServerListen(ourReader, ourWriter, ourEvents, elevatedToken)
   189  			theirLogMapping, err := ringlogger.Global.ExportInheritableMappingHandle()
   190  			if err != nil {
   191  				log.Printf("Unable to export inheritable mapping handle for logging: %v", err)
   192  				return
   193  			}
   194  
   195  			log.Printf("Starting UI process for user ā€˜%s@%sā€™ for session %d", username, domain, session)
   196  			procsLock.Lock()
   197  			var proc *uiProcess
   198  			if alive := aliveSessions[session]; alive {
   199  				proc, err = launchUIProcess(path, []string{
   200  					path,
   201  					"/ui",
   202  					strconv.FormatUint(uint64(theirReader.Fd()), 10),
   203  					strconv.FormatUint(uint64(theirWriter.Fd()), 10),
   204  					strconv.FormatUint(uint64(theirEvents.Fd()), 10),
   205  					strconv.FormatUint(uint64(theirLogMapping), 10),
   206  				}, userProfileDirectory, []windows.Handle{
   207  					windows.Handle(theirReader.Fd()),
   208  					windows.Handle(theirWriter.Fd()),
   209  					windows.Handle(theirEvents.Fd()),
   210  					theirLogMapping,
   211  				}, runToken)
   212  			} else {
   213  				err = errors.New("Session has logged out")
   214  			}
   215  			procsLock.Unlock()
   216  			theirReader.Close()
   217  			theirWriter.Close()
   218  			theirEvents.Close()
   219  			windows.CloseHandle(theirLogMapping)
   220  			if err != nil {
   221  				ourReader.Close()
   222  				ourWriter.Close()
   223  				ourEvents.Close()
   224  				log.Printf("Unable to start manager UI process for user '%s@%s' for session %d: %v", username, domain, session, err)
   225  				return
   226  			}
   227  
   228  			procsLock.Lock()
   229  			procs[session] = proc
   230  			procsLock.Unlock()
   231  
   232  			sessionIsDead := false
   233  			if exitCode, err := proc.Wait(); err == nil {
   234  				log.Printf("Exited UI process for user '%s@%s' for session %d with status %x", username, domain, session, exitCode)
   235  				const STATUS_DLL_INIT_FAILED_LOGOFF = 0xC000026B
   236  				sessionIsDead = exitCode == STATUS_DLL_INIT_FAILED_LOGOFF
   237  			} else {
   238  				log.Printf("Unable to wait for UI process for user '%s@%s' for session %d: %v", username, domain, session, err)
   239  			}
   240  
   241  			procsLock.Lock()
   242  			delete(procs, session)
   243  			procsLock.Unlock()
   244  			ourReader.Close()
   245  			ourWriter.Close()
   246  			ourEvents.Close()
   247  
   248  			if sessionIsDead {
   249  				return
   250  			}
   251  		}
   252  	}
   253  	procsGroup := sync.WaitGroup{}
   254  	goStartProcess := func(session uint32) {
   255  		procsGroup.Add(1)
   256  		go func() {
   257  			startProcess(session)
   258  			procsGroup.Done()
   259  		}()
   260  	}
   261  
   262  	go checkForUpdates()
   263  	go driver.UninstallLegacyWintun() // We uninstall opportunistically here, so that we don't have to carry around the uninstaller code forever.
   264  
   265  	var sessionsPointer *windows.WTS_SESSION_INFO
   266  	var count uint32
   267  	err = windows.WTSEnumerateSessions(0, 0, 1, &sessionsPointer, &count)
   268  	if err != nil {
   269  		serviceError = services.ErrorEnumerateSessions
   270  		return
   271  	}
   272  	for _, session := range unsafe.Slice(sessionsPointer, count) {
   273  		if session.State != windows.WTSActive && session.State != windows.WTSDisconnected {
   274  			continue
   275  		}
   276  		procsLock.Lock()
   277  		if alive := aliveSessions[session.SessionID]; !alive {
   278  			aliveSessions[session.SessionID] = true
   279  			if _, ok := procs[session.SessionID]; !ok {
   280  				goStartProcess(session.SessionID)
   281  			}
   282  		}
   283  		procsLock.Unlock()
   284  	}
   285  	windows.WTSFreeMemory(uintptr(unsafe.Pointer(sessionsPointer)))
   286  
   287  	changes <- svc.Status{State: svc.Running, Accepts: svc.AcceptStop | svc.AcceptSessionChange}
   288  
   289  	uninstall := false
   290  loop:
   291  	for {
   292  		select {
   293  		case <-quitManagersChan:
   294  			uninstall = true
   295  			break loop
   296  		case c := <-r:
   297  			switch c.Cmd {
   298  			case svc.Stop:
   299  				break loop
   300  			case svc.Interrogate:
   301  				changes <- c.CurrentStatus
   302  			case svc.SessionChange:
   303  				if c.EventType != windows.WTS_SESSION_LOGON && c.EventType != windows.WTS_SESSION_LOGOFF {
   304  					continue
   305  				}
   306  				sessionNotification := (*windows.WTSSESSION_NOTIFICATION)(unsafe.Pointer(c.EventData))
   307  				if uintptr(sessionNotification.Size) != unsafe.Sizeof(*sessionNotification) {
   308  					log.Printf("Unexpected size of WTSSESSION_NOTIFICATION: %d", sessionNotification.Size)
   309  					continue
   310  				}
   311  				if c.EventType == windows.WTS_SESSION_LOGOFF {
   312  					procsLock.Lock()
   313  					delete(aliveSessions, sessionNotification.SessionID)
   314  					if proc, ok := procs[sessionNotification.SessionID]; ok {
   315  						proc.Kill()
   316  					}
   317  					procsLock.Unlock()
   318  				} else if c.EventType == windows.WTS_SESSION_LOGON {
   319  					procsLock.Lock()
   320  					if alive := aliveSessions[sessionNotification.SessionID]; !alive {
   321  						aliveSessions[sessionNotification.SessionID] = true
   322  						if _, ok := procs[sessionNotification.SessionID]; !ok {
   323  							goStartProcess(sessionNotification.SessionID)
   324  						}
   325  					}
   326  					procsLock.Unlock()
   327  				}
   328  
   329  			default:
   330  				log.Printf("Unexpected service control request #%d", c)
   331  			}
   332  		}
   333  	}
   334  
   335  	changes <- svc.Status{State: svc.StopPending}
   336  	procsLock.Lock()
   337  	stoppingManager = true
   338  	IPCServerNotifyManagerStopping()
   339  	for _, proc := range procs {
   340  		proc.Kill()
   341  	}
   342  	procsLock.Unlock()
   343  	procsGroup.Wait()
   344  	if uninstall {
   345  		err = UninstallManager()
   346  		if err != nil {
   347  			log.Printf("Unable to uninstall manager when quitting: %v", err)
   348  		}
   349  	}
   350  	return
   351  }
   352  
   353  func Run() error {
   354  	return svc.Run("WireGuardManager", &managerService{})
   355  }