github.com/keybase/client/go@v0.0.0-20241007131713-f10651d043c8/updater/process/process_test.go (about)

     1  // Copyright 2015 Keybase, Inc. All rights reserved. Use of
     2  // this source code is governed by the included BSD license.
     3  
     4  package process
     5  
     6  import (
     7  	"fmt"
     8  	"os"
     9  	"os/exec"
    10  	"path/filepath"
    11  	"runtime"
    12  	"testing"
    13  	"time"
    14  
    15  	"github.com/keybase/client/go/updater/util"
    16  	"github.com/keybase/go-logging"
    17  	"github.com/keybase/go-ps"
    18  	"github.com/stretchr/testify/assert"
    19  	"github.com/stretchr/testify/require"
    20  )
    21  
    22  var testLog = &logging.Logger{Module: "test"}
    23  
    24  var matchAll = func(p ps.Process) bool { return true }
    25  
    26  func cleanupProc(cmd *exec.Cmd, procPath string) {
    27  	if cmd != nil && cmd.Process != nil {
    28  		_ = cmd.Process.Kill()
    29  	}
    30  	if procPath != "" {
    31  		_ = os.Remove(procPath)
    32  	}
    33  }
    34  
    35  func procTestPath(name string) (string, string) {
    36  	// Copy test executable to tmp
    37  	if runtime.GOOS == "windows" {
    38  		return filepath.Join(os.Getenv("GOPATH"), "bin", "test.exe"), filepath.Join(os.TempDir(), name+".exe")
    39  	}
    40  	return filepath.Join(os.Getenv("GOPATH"), "bin", "test"), filepath.Join(os.TempDir(), name)
    41  }
    42  
    43  func procPath(t *testing.T, name string) string {
    44  	// Copy test executable to tmp
    45  	srcPath, destPath := procTestPath(name)
    46  	err := util.CopyFile(srcPath, destPath, testLog)
    47  	require.NoError(t, err)
    48  	err = os.Chmod(destPath, 0777)
    49  	require.NoError(t, err)
    50  	// Temp dir might have symlinks in which case we need the eval'ed path
    51  	destPath, err = filepath.EvalSymlinks(destPath)
    52  	require.NoError(t, err)
    53  	return destPath
    54  }
    55  
    56  func TestFindPIDsWithFn(t *testing.T) {
    57  	pids, err := findPIDsWithFn(ps.Processes, matchAll, testLog)
    58  	assert.NoError(t, err)
    59  	assert.True(t, len(pids) > 1)
    60  
    61  	fn := func() ([]ps.Process, error) {
    62  		return nil, fmt.Errorf("Testing error")
    63  	}
    64  	processes, err := findPIDsWithFn(fn, matchAll, testLog)
    65  	assert.Nil(t, processes)
    66  	assert.Error(t, err)
    67  
    68  	fn = func() ([]ps.Process, error) {
    69  		return nil, nil
    70  	}
    71  	processes, err = findPIDsWithFn(fn, matchAll, testLog)
    72  	assert.Equal(t, []int{}, processes)
    73  	assert.NoError(t, err)
    74  }
    75  
    76  func TestTerminatePID(t *testing.T) {
    77  	procPath := procPath(t, "testTerminatePID")
    78  	cmd := exec.Command(procPath, "sleep")
    79  	err := cmd.Start()
    80  	defer cleanupProc(cmd, procPath)
    81  	require.NoError(t, err)
    82  	require.NotNil(t, cmd.Process)
    83  
    84  	err = TerminatePID(cmd.Process.Pid, time.Millisecond, testLog)
    85  	assert.NoError(t, err)
    86  }
    87  
    88  func assertTerminated(t *testing.T, pid int, stateStr string) {
    89  	process, err := os.FindProcess(pid)
    90  	require.NoError(t, err)
    91  	state, err := process.Wait()
    92  	require.NoError(t, err)
    93  	assert.Equal(t, stateStr, state.String())
    94  }
    95  
    96  func TestTerminatePIDInvalid(t *testing.T) {
    97  	err := TerminatePID(-5, time.Millisecond, testLog)
    98  	assert.Error(t, err)
    99  }
   100  
   101  func TestTerminateAllFn(t *testing.T) {
   102  	fn := func() ([]ps.Process, error) {
   103  		return nil, fmt.Errorf("Testing error")
   104  	}
   105  	TerminateAllWithProcessesFn(fn, matchAll, time.Millisecond, testLog)
   106  
   107  	fn = func() ([]ps.Process, error) {
   108  		return nil, nil
   109  	}
   110  	TerminateAllWithProcessesFn(fn, matchAll, time.Millisecond, testLog)
   111  }
   112  
   113  func startProcess(t *testing.T, path string, testCommand string) (string, int, *exec.Cmd) {
   114  	cmd := exec.Command(path, testCommand)
   115  	err := cmd.Start()
   116  	require.NoError(t, err)
   117  	require.NotNil(t, cmd.Process)
   118  	return path, cmd.Process.Pid, cmd
   119  }
   120  
   121  func TestTerminateAllPathEqual(t *testing.T) {
   122  	if runtime.GOOS == "windows" {
   123  		t.Skip("flakey :(")
   124  	}
   125  	procPath := procPath(t, "testTerminateAllPathEqual")
   126  	defer util.RemoveFileAtPath(procPath)
   127  	matcher := NewMatcher(procPath, PathEqual, testLog)
   128  	testTerminateAll(t, procPath, matcher, 2)
   129  }
   130  
   131  func TestTerminateAllExecutableEqual(t *testing.T) {
   132  	procPath := procPath(t, "testTerminateAllExecutableEqual")
   133  	defer util.RemoveFileAtPath(procPath)
   134  	matcher := NewMatcher(filepath.Base(procPath), ExecutableEqual, testLog)
   135  	testTerminateAll(t, procPath, matcher, 2)
   136  }
   137  
   138  func TestTerminateAllPathContains(t *testing.T) {
   139  	procPath := procPath(t, "testTerminateAllPathContains")
   140  	defer util.RemoveFileAtPath(procPath)
   141  	procDir, procFile := filepath.Split(procPath)
   142  	match := procDir[1:] + procFile[:20]
   143  	t.Logf("Match: %q", match)
   144  	matcher := NewMatcher(match, PathContains, testLog)
   145  	testTerminateAll(t, procPath, matcher, 2)
   146  }
   147  
   148  func TestTerminateAllPathPrefix(t *testing.T) {
   149  	procPath := procPath(t, "testTerminateAllPathPrefix")
   150  	defer util.RemoveFileAtPath(procPath)
   151  	procDir, procFile := filepath.Split(procPath)
   152  	match := procDir + procFile[:20]
   153  	t.Logf("Match: %q", match)
   154  	matcher := NewMatcher(match, PathPrefix, testLog)
   155  	testTerminateAll(t, procPath, matcher, 2)
   156  }
   157  
   158  func testTerminateAll(t *testing.T, path string, matcher Matcher, numProcs int) {
   159  	var exitStatus string
   160  	if runtime.GOOS == "windows" {
   161  		exitStatus = "exit status 1"
   162  	} else {
   163  		exitStatus = "signal: terminated"
   164  	}
   165  
   166  	pids := []int{}
   167  	for i := 0; i < numProcs; i++ {
   168  		procPath, pid, cmd := startProcess(t, path, "sleep")
   169  		t.Logf("Started process %q (%d)", procPath, pid)
   170  		pids = append(pids, pid)
   171  		defer cleanupProc(cmd, "")
   172  	}
   173  
   174  	time.Sleep(time.Second)
   175  
   176  	terminatePids := TerminateAll(matcher, time.Second, testLog)
   177  	for _, p := range pids {
   178  		assert.Contains(t, terminatePids, p)
   179  		assertTerminated(t, p, exitStatus)
   180  	}
   181  }
   182  
   183  func TestFindProcessWait(t *testing.T) {
   184  	if runtime.GOOS == "windows" {
   185  		t.Skip("Skipping on windows")
   186  	}
   187  	procPath := procPath(t, "testFindProcessWait")
   188  	cmd := exec.Command(procPath, "sleep")
   189  	defer cleanupProc(cmd, procPath)
   190  
   191  	// Ensure it's not already running
   192  	procs, err := FindProcesses(NewMatcher(procPath, PathEqual, testLog), time.Millisecond, 0, testLog)
   193  	require.NoError(t, err)
   194  	require.Equal(t, 0, len(procs))
   195  
   196  	go func() {
   197  		time.Sleep(10 * time.Millisecond)
   198  		startErr := cmd.Start()
   199  		require.NoError(t, startErr)
   200  	}()
   201  
   202  	// Wait up to second for process to be running
   203  	procs, err = FindProcesses(NewMatcher(procPath, PathEqual, testLog), time.Second, 10*time.Millisecond, testLog)
   204  	require.NoError(t, err)
   205  	require.True(t, len(procs) == 1)
   206  }