golang.zx2c4.com/wireguard/windows@v0.5.4-0.20230123132234-dcc0eb72a04b/main.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package main
     7  
     8  import (
     9  	"debug/pe"
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"log"
    14  	"os"
    15  	"strconv"
    16  	"strings"
    17  	"time"
    18  
    19  	"golang.org/x/sys/windows"
    20  
    21  	"golang.zx2c4.com/wireguard/windows/conf"
    22  	"golang.zx2c4.com/wireguard/windows/driver"
    23  	"golang.zx2c4.com/wireguard/windows/elevate"
    24  	"golang.zx2c4.com/wireguard/windows/l18n"
    25  	"golang.zx2c4.com/wireguard/windows/manager"
    26  	"golang.zx2c4.com/wireguard/windows/ringlogger"
    27  	"golang.zx2c4.com/wireguard/windows/tunnel"
    28  	"golang.zx2c4.com/wireguard/windows/ui"
    29  	"golang.zx2c4.com/wireguard/windows/updater"
    30  )
    31  
    32  func setLogFile() {
    33  	logHandle, err := windows.GetStdHandle(windows.STD_ERROR_HANDLE)
    34  	if logHandle == 0 || err != nil {
    35  		logHandle, err = windows.GetStdHandle(windows.STD_OUTPUT_HANDLE)
    36  	}
    37  	if logHandle == 0 || err != nil {
    38  		log.SetOutput(io.Discard)
    39  	} else {
    40  		log.SetOutput(os.NewFile(uintptr(logHandle), "stderr"))
    41  	}
    42  }
    43  
    44  func fatal(v ...any) {
    45  	if log.Writer() == io.Discard {
    46  		windows.MessageBox(0, windows.StringToUTF16Ptr(fmt.Sprint(v...)), windows.StringToUTF16Ptr(l18n.Sprintf("Error")), windows.MB_ICONERROR)
    47  		os.Exit(1)
    48  	} else {
    49  		log.Fatal(append([]any{l18n.Sprintf("Error: ")}, v...))
    50  	}
    51  }
    52  
    53  func fatalf(format string, v ...any) {
    54  	fatal(l18n.Sprintf(format, v...))
    55  }
    56  
    57  func info(title, format string, v ...any) {
    58  	if log.Writer() == io.Discard {
    59  		windows.MessageBox(0, windows.StringToUTF16Ptr(l18n.Sprintf(format, v...)), windows.StringToUTF16Ptr(title), windows.MB_ICONINFORMATION)
    60  	} else {
    61  		log.Printf(title+":\n"+format, v...)
    62  	}
    63  }
    64  
    65  func usage() {
    66  	flags := [...]string{
    67  		l18n.Sprintf("(no argument): elevate and install manager service"),
    68  		"/installmanagerservice",
    69  		"/installtunnelservice CONFIG_PATH",
    70  		"/uninstallmanagerservice",
    71  		"/uninstalltunnelservice TUNNEL_NAME",
    72  		"/managerservice",
    73  		"/tunnelservice CONFIG_PATH",
    74  		"/ui CMD_READ_HANDLE CMD_WRITE_HANDLE CMD_EVENT_HANDLE LOG_MAPPING_HANDLE",
    75  		"/dumplog [/tail]",
    76  		"/update",
    77  		"/removedriver",
    78  	}
    79  	builder := strings.Builder{}
    80  	for _, flag := range flags {
    81  		builder.WriteString(fmt.Sprintf("    %s\n", flag))
    82  	}
    83  	info(l18n.Sprintf("Command Line Options"), "Usage: %s [\n%s]", os.Args[0], builder.String())
    84  	os.Exit(1)
    85  }
    86  
    87  func checkForWow64() {
    88  	b, err := func() (bool, error) {
    89  		var processMachine, nativeMachine uint16
    90  		err := windows.IsWow64Process2(windows.CurrentProcess(), &processMachine, &nativeMachine)
    91  		if err == nil {
    92  			return processMachine != pe.IMAGE_FILE_MACHINE_UNKNOWN, nil
    93  		}
    94  		if !errors.Is(err, windows.ERROR_PROC_NOT_FOUND) {
    95  			return false, err
    96  		}
    97  		var b bool
    98  		err = windows.IsWow64Process(windows.CurrentProcess(), &b)
    99  		if err != nil {
   100  			return false, err
   101  		}
   102  		return b, nil
   103  	}()
   104  	if err != nil {
   105  		fatalf("Unable to determine whether the process is running under WOW64: %v", err)
   106  	}
   107  	if b {
   108  		fatalf("You must use the native version of WireGuard on this computer.")
   109  	}
   110  }
   111  
   112  func checkForAdminGroup() {
   113  	// This is not a security check, but rather a user-confusion one.
   114  	var processToken windows.Token
   115  	err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_QUERY|windows.TOKEN_DUPLICATE, &processToken)
   116  	if err != nil {
   117  		fatalf("Unable to open current process token: %v", err)
   118  	}
   119  	defer processToken.Close()
   120  	if !elevate.TokenIsElevatedOrElevatable(processToken) {
   121  		fatalf("WireGuard may only be used by users who are a member of the Builtin %s group.", elevate.AdminGroupName())
   122  	}
   123  }
   124  
   125  func checkForAdminDesktop() {
   126  	adminDesktop, err := elevate.IsAdminDesktop()
   127  	if !adminDesktop && err == nil {
   128  		fatalf("WireGuard is running, but the UI is only accessible from desktops of the Builtin %s group.", elevate.AdminGroupName())
   129  	}
   130  }
   131  
   132  func execElevatedManagerServiceInstaller() error {
   133  	path, err := os.Executable()
   134  	if err != nil {
   135  		return err
   136  	}
   137  	err = elevate.ShellExecute(path, "/installmanagerservice", "", windows.SW_SHOW)
   138  	if err != nil && err != windows.ERROR_CANCELLED {
   139  		return err
   140  	}
   141  	os.Exit(0)
   142  	return windows.ERROR_UNHANDLED_EXCEPTION // Not reached
   143  }
   144  
   145  func pipeFromHandleArgument(handleStr string) (*os.File, error) {
   146  	handleInt, err := strconv.ParseUint(handleStr, 10, 64)
   147  	if err != nil {
   148  		return nil, err
   149  	}
   150  	return os.NewFile(uintptr(handleInt), "pipe"), nil
   151  }
   152  
   153  func main() {
   154  	if windows.SetDllDirectory("") != nil || windows.SetDefaultDllDirectories(windows.LOAD_LIBRARY_SEARCH_SYSTEM32) != nil {
   155  		panic("failed to restrict dll search path")
   156  	}
   157  
   158  	setLogFile()
   159  	checkForWow64()
   160  
   161  	if len(os.Args) <= 1 {
   162  		if ui.RaiseUI() {
   163  			return
   164  		}
   165  		checkForAdminGroup()
   166  		err := execElevatedManagerServiceInstaller()
   167  		if err != nil {
   168  			fatal(err)
   169  		}
   170  		return
   171  	}
   172  	switch os.Args[1] {
   173  	case "/installmanagerservice":
   174  		if len(os.Args) != 2 {
   175  			usage()
   176  		}
   177  		go ui.WaitForRaiseUIThenQuit()
   178  		err := manager.InstallManager()
   179  		if err != nil {
   180  			if err == manager.ErrManagerAlreadyRunning {
   181  				checkForAdminDesktop()
   182  			}
   183  			fatal(err)
   184  		}
   185  		checkForAdminDesktop()
   186  		time.Sleep(30 * time.Second)
   187  		fatalf("WireGuard system tray icon did not appear after 30 seconds.")
   188  		return
   189  	case "/uninstallmanagerservice":
   190  		if len(os.Args) != 2 {
   191  			usage()
   192  		}
   193  		err := manager.UninstallManager()
   194  		if err != nil {
   195  			fatal(err)
   196  		}
   197  		return
   198  	case "/managerservice":
   199  		if len(os.Args) != 2 {
   200  			usage()
   201  		}
   202  		err := manager.Run()
   203  		if err != nil {
   204  			fatal(err)
   205  		}
   206  		return
   207  	case "/installtunnelservice":
   208  		if len(os.Args) != 3 {
   209  			usage()
   210  		}
   211  		err := manager.InstallTunnel(os.Args[2])
   212  		if err != nil {
   213  			fatal(err)
   214  		}
   215  		return
   216  	case "/uninstalltunnelservice":
   217  		if len(os.Args) != 3 {
   218  			usage()
   219  		}
   220  		err := manager.UninstallTunnel(os.Args[2])
   221  		if err != nil {
   222  			fatal(err)
   223  		}
   224  		return
   225  	case "/tunnelservice":
   226  		if len(os.Args) != 3 {
   227  			usage()
   228  		}
   229  		err := tunnel.Run(os.Args[2])
   230  		if err != nil {
   231  			fatal(err)
   232  		}
   233  		return
   234  	case "/ui":
   235  		if len(os.Args) != 6 {
   236  			usage()
   237  		}
   238  		var processToken windows.Token
   239  		isAdmin := false
   240  		err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_QUERY|windows.TOKEN_DUPLICATE, &processToken)
   241  		if err == nil {
   242  			isAdmin = elevate.TokenIsElevatedOrElevatable(processToken)
   243  			processToken.Close()
   244  		}
   245  		if isAdmin {
   246  			err := elevate.DropAllPrivileges(false)
   247  			if err != nil {
   248  				fatal(err)
   249  			}
   250  		}
   251  		readPipe, err := pipeFromHandleArgument(os.Args[2])
   252  		if err != nil {
   253  			fatal(err)
   254  		}
   255  		writePipe, err := pipeFromHandleArgument(os.Args[3])
   256  		if err != nil {
   257  			fatal(err)
   258  		}
   259  		eventPipe, err := pipeFromHandleArgument(os.Args[4])
   260  		if err != nil {
   261  			fatal(err)
   262  		}
   263  		ringlogger.Global, err = ringlogger.NewRingloggerFromInheritedMappingHandle(os.Args[5], "GUI")
   264  		if err != nil {
   265  			fatal(err)
   266  		}
   267  		manager.InitializeIPCClient(readPipe, writePipe, eventPipe)
   268  		ui.IsAdmin = isAdmin
   269  		ui.RunUI()
   270  		return
   271  	case "/dumplog":
   272  		if len(os.Args) != 2 && len(os.Args) != 3 {
   273  			usage()
   274  		}
   275  		outputHandle, err := windows.GetStdHandle(windows.STD_OUTPUT_HANDLE)
   276  		if err != nil {
   277  			fatal(err)
   278  		}
   279  		if outputHandle == 0 {
   280  			fatal("Stdout must be set")
   281  		}
   282  		file := os.NewFile(uintptr(outputHandle), "stdout")
   283  		defer file.Close()
   284  		logPath, err := conf.LogFile(false)
   285  		if err != nil {
   286  			fatal(err)
   287  		}
   288  		err = ringlogger.DumpTo(logPath, file, len(os.Args) == 3 && os.Args[2] == "/tail")
   289  		if err != nil {
   290  			fatal(err)
   291  		}
   292  		return
   293  	case "/update":
   294  		if len(os.Args) != 2 {
   295  			usage()
   296  		}
   297  		for progress := range updater.DownloadVerifyAndExecute(0) {
   298  			if len(progress.Activity) > 0 {
   299  				if progress.BytesTotal > 0 || progress.BytesDownloaded > 0 {
   300  					var percent float64
   301  					if progress.BytesTotal > 0 {
   302  						percent = float64(progress.BytesDownloaded) / float64(progress.BytesTotal) * 100.0
   303  					}
   304  					log.Printf("%s: %d/%d (%.2f%%)\n", progress.Activity, progress.BytesDownloaded, progress.BytesTotal, percent)
   305  				} else {
   306  					log.Println(progress.Activity)
   307  				}
   308  			}
   309  			if progress.Error != nil {
   310  				log.Printf("Error: %v\n", progress.Error)
   311  			}
   312  			if progress.Complete || progress.Error != nil {
   313  				return
   314  			}
   315  		}
   316  		return
   317  	case "/removedriver":
   318  		if len(os.Args) != 2 {
   319  			usage()
   320  		}
   321  		_ = driver.UninstallLegacyWintun() // Best effort
   322  		err := driver.Uninstall()
   323  		if err != nil {
   324  			fatal(err)
   325  		}
   326  		return
   327  	}
   328  	usage()
   329  }