github.com/telepresenceio/telepresence/v2@v2.20.0-pro.6.0.20240517030216-236ea954e789/pkg/proc/exec_windows.go (about)

     1  package proc
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"os"
     8  	"os/exec" //nolint:depguard // We want no logging and no soft-context signal handling
     9  	"unsafe"
    10  
    11  	"golang.org/x/sys/windows"
    12  
    13  	"github.com/datawire/dlib/dexec"
    14  	"github.com/datawire/dlib/dlog"
    15  	"github.com/telepresenceio/telepresence/v2/pkg/shellquote"
    16  )
    17  
    18  var SignalsToForward = []os.Signal{os.Interrupt} //nolint:gochecknoglobals // OS-specific constant list
    19  
    20  // SIGTERM uses os.Interrupt on Windows as a best effort.
    21  var SIGTERM = os.Interrupt //nolint:gochecknoglobals // OS-specific constant
    22  
    23  func CommandContext(ctx context.Context, name string, args ...string) *dexec.Cmd {
    24  	cmd := dexec.CommandContext(ctx, name, args...)
    25  	createNewProcessGroup(cmd.Cmd)
    26  	return cmd
    27  }
    28  
    29  func createNewProcessGroup(cmd *exec.Cmd) {
    30  	cmd.SysProcAttr = &windows.SysProcAttr{CreationFlags: windows.CREATE_NEW_PROCESS_GROUP}
    31  }
    32  
    33  func cacheAdmin(_ context.Context, _ string) error {
    34  	// No-op on windows, there's no sudo caching. Runas will just pop a window open.
    35  	return nil
    36  }
    37  
    38  func startInBackground(_ bool, args ...string) error {
    39  	return shellExec("open", args[0], args[1:]...)
    40  }
    41  
    42  func startInBackgroundAsRoot(_ context.Context, args ...string) error {
    43  	verb := "runas"
    44  	if isAdmin() {
    45  		verb = "open"
    46  	}
    47  	return shellExec(verb, args[0], args[1:]...)
    48  }
    49  
    50  func shellExec(verb, exe string, args ...string) error {
    51  	cwd, _ := os.Getwd()
    52  	// UTF16PtrFromString can only fail if the argument contains a NUL byte. That will never happen here.
    53  	verbPtr, _ := windows.UTF16PtrFromString(verb)
    54  	exePtr, _ := windows.UTF16PtrFromString(exe)
    55  	cwdPtr, _ := windows.UTF16PtrFromString(cwd)
    56  	var argPtr *uint16
    57  	if len(args) > 0 {
    58  		argsStr := shellquote.ShellArgsString(args)
    59  		argPtr, _ = windows.UTF16PtrFromString(argsStr)
    60  	}
    61  	return windows.ShellExecute(0, verbPtr, exePtr, argPtr, cwdPtr, windows.SW_HIDE)
    62  }
    63  
    64  func isAdmin() bool {
    65  	// Directly copied from the official Windows documentation. The Go API for this is a
    66  	// direct wrap around the official C++ API.
    67  	// See https://docs.microsoft.com/en-us/windows/desktop/api/securitybaseapi/nf-securitybaseapi-checktokenmembership
    68  	var sid *windows.SID
    69  	err := windows.AllocateAndInitializeSid(
    70  		&windows.SECURITY_NT_AUTHORITY,
    71  		2,
    72  		windows.SECURITY_BUILTIN_DOMAIN_RID,
    73  		windows.DOMAIN_ALIAS_RID_ADMINS,
    74  		0, 0, 0, 0, 0, 0,
    75  		&sid)
    76  	if err != nil {
    77  		return false
    78  	}
    79  	adm, err := windows.GetCurrentProcessToken().IsMember(sid)
    80  	return err == nil && adm
    81  }
    82  
    83  func terminate(p *os.Process) error {
    84  	return p.Kill()
    85  }
    86  
    87  const peSize = uint32(unsafe.Sizeof(windows.ProcessEntry32{}))
    88  
    89  type processInfo struct {
    90  	pid  uint32
    91  	ppid uint32
    92  	exe  string
    93  }
    94  
    95  func killProcessGroup(ctx context.Context, cmd *exec.Cmd, sig os.Signal) {
    96  	pes := make([]*processInfo, 0, 100)
    97  	err := eachProcess(func(pe *windows.ProcessEntry32) bool {
    98  		pes = append(pes, &processInfo{
    99  			pid:  pe.ProcessID,
   100  			ppid: pe.ParentProcessID,
   101  			exe:  windows.UTF16ToString(pe.ExeFile[:]),
   102  		})
   103  		return true
   104  	})
   105  	if err != nil {
   106  		dlog.Error(ctx, err)
   107  	} else if err = terminateProcess(ctx, cmd.Path, uint32(cmd.Process.Pid), sig, pes); err != nil {
   108  		dlog.Error(ctx, err)
   109  	}
   110  }
   111  
   112  // terminateProcess will terminate the given process and all its children. The
   113  // children are terminated first.
   114  func terminateProcess(ctx context.Context, exe string, pid uint32, sig os.Signal, pes []*processInfo) error {
   115  	if err := terminateChildrenOf(ctx, pid, sig, pes); err != nil {
   116  		return err
   117  	}
   118  
   119  	if sig == os.Interrupt {
   120  		if err := windows.GenerateConsoleCtrlEvent(windows.CTRL_BREAK_EVENT, pid); err != nil {
   121  			// An ACCESS_DENIED error may indicate that the process is dead already but
   122  			// died just after the handle to it was opened.
   123  			if errors.Is(err, windows.ERROR_ACCESS_DENIED) {
   124  				if alive, aliveErr := processIsAlive(pid); aliveErr != nil {
   125  					dlog.Error(ctx, aliveErr)
   126  				} else if !alive {
   127  					return nil
   128  				}
   129  			}
   130  			return fmt.Errorf("%q: %w", exe, &os.SyscallError{Syscall: "GenerateConsoleCtrlEvent", Err: err})
   131  		}
   132  		dlog.Debugf(ctx, "sent ctrl-c to process %q (pid %d)", exe, pid)
   133  		return nil
   134  	}
   135  
   136  	// SYNCHRONIZE is required to wait for the process to terminate
   137  	h, err := windows.OpenProcess(windows.SYNCHRONIZE|windows.PROCESS_TERMINATE, true, pid)
   138  	if err != nil {
   139  		if errors.Is(err, windows.ERROR_INVALID_PARAMETER) {
   140  			// ERROR_INVALID_PARAMETER means that the process no longer exists. It might
   141  			// have died because we killed its children.
   142  			return nil
   143  		}
   144  		return fmt.Errorf("failed to open handle of %q: %w", exe, err)
   145  	}
   146  	defer func() {
   147  		_ = windows.CloseHandle(h)
   148  	}()
   149  
   150  	if err = windows.TerminateProcess(h, 0); err != nil {
   151  		// An ACCESS_DENIED error may indicate that the process is dead already but
   152  		// died just after the handle to it was opened.
   153  		if errors.Is(err, windows.ERROR_ACCESS_DENIED) {
   154  			if alive, aliveErr := processIsAlive(pid); aliveErr != nil {
   155  				dlog.Error(ctx, aliveErr)
   156  			} else if !alive {
   157  				return nil
   158  			}
   159  		}
   160  		return fmt.Errorf("%q: %w", exe, &os.SyscallError{Syscall: "TerminateProcess", Err: err})
   161  	}
   162  	dlog.Debugf(ctx, "terminated process %q (pid %d)", exe, pid)
   163  	return nil
   164  }
   165  
   166  func terminateChildrenOf(ctx context.Context, pid uint32, sig os.Signal, pes []*processInfo) error {
   167  	for _, pe := range pes {
   168  		if pe.ppid == pid {
   169  			if err := terminateProcess(ctx, pe.exe, pe.pid, sig, pes); err != nil {
   170  				return err
   171  			}
   172  		}
   173  	}
   174  	return nil
   175  }
   176  
   177  // processIsAlive checks if the given pid exists in the current process snapshot.
   178  func processIsAlive(pid uint32) (bool, error) {
   179  	found := false
   180  	err := eachProcess(func(pe *windows.ProcessEntry32) bool {
   181  		if pe.ProcessID == pid {
   182  			found = true
   183  			return false // break iteration
   184  		}
   185  		return true
   186  	})
   187  	if err != nil {
   188  		return false, err
   189  	}
   190  	return found, nil
   191  }
   192  
   193  // eachProcess calls the given function with each ProcessEntry32 found
   194  // in the current process snapshot. The iteration ends if the given function
   195  // returns false.
   196  func eachProcess(f func(pe *windows.ProcessEntry32) bool) error {
   197  	h, err := windows.CreateToolhelp32Snapshot(windows.TH32CS_SNAPPROCESS, 0)
   198  	if err != nil {
   199  		return fmt.Errorf("unable to get process snapshot: %w", err)
   200  	}
   201  	defer func() {
   202  		_ = windows.CloseHandle(h)
   203  	}()
   204  	pe := new(windows.ProcessEntry32)
   205  	pe.Size = peSize
   206  	err = windows.Process32First(h, pe)
   207  	for err == nil {
   208  		if !f(pe) {
   209  			break
   210  		}
   211  		err = windows.Process32Next(h, pe)
   212  	}
   213  	return nil
   214  }