github.com/keybase/client/go@v0.0.0-20240309051027-028f7c731f8b/updater/process/process_test.go (about) 1 // Copyright 2015 Keybase, Inc. All rights reserved. Use of 2 // this source code is governed by the included BSD license. 3 4 package process 5 6 import ( 7 "fmt" 8 "os" 9 "os/exec" 10 "path/filepath" 11 "runtime" 12 "testing" 13 "time" 14 15 "github.com/keybase/client/go/updater/util" 16 "github.com/keybase/go-logging" 17 "github.com/keybase/go-ps" 18 "github.com/stretchr/testify/assert" 19 "github.com/stretchr/testify/require" 20 ) 21 22 var testLog = &logging.Logger{Module: "test"} 23 24 var matchAll = func(p ps.Process) bool { return true } 25 26 func cleanupProc(cmd *exec.Cmd, procPath string) { 27 if cmd != nil && cmd.Process != nil { 28 _ = cmd.Process.Kill() 29 } 30 if procPath != "" { 31 _ = os.Remove(procPath) 32 } 33 } 34 35 func procTestPath(name string) (string, string) { 36 // Copy test executable to tmp 37 if runtime.GOOS == "windows" { 38 return filepath.Join(os.Getenv("GOPATH"), "bin", "test.exe"), filepath.Join(os.TempDir(), name+".exe") 39 } 40 return filepath.Join(os.Getenv("GOPATH"), "bin", "test"), filepath.Join(os.TempDir(), name) 41 } 42 43 func procPath(t *testing.T, name string) string { 44 // Copy test executable to tmp 45 srcPath, destPath := procTestPath(name) 46 err := util.CopyFile(srcPath, destPath, testLog) 47 require.NoError(t, err) 48 err = os.Chmod(destPath, 0777) 49 require.NoError(t, err) 50 // Temp dir might have symlinks in which case we need the eval'ed path 51 destPath, err = filepath.EvalSymlinks(destPath) 52 require.NoError(t, err) 53 return destPath 54 } 55 56 func TestFindPIDsWithFn(t *testing.T) { 57 pids, err := findPIDsWithFn(ps.Processes, matchAll, testLog) 58 assert.NoError(t, err) 59 assert.True(t, len(pids) > 1) 60 61 fn := func() ([]ps.Process, error) { 62 return nil, fmt.Errorf("Testing error") 63 } 64 processes, err := findPIDsWithFn(fn, matchAll, testLog) 65 assert.Nil(t, processes) 66 assert.Error(t, err) 67 68 fn = func() ([]ps.Process, error) { 69 return nil, nil 70 } 71 processes, err = findPIDsWithFn(fn, matchAll, testLog) 72 assert.Equal(t, []int{}, processes) 73 assert.NoError(t, err) 74 } 75 76 func TestTerminatePID(t *testing.T) { 77 procPath := procPath(t, "testTerminatePID") 78 cmd := exec.Command(procPath, "sleep") 79 err := cmd.Start() 80 defer cleanupProc(cmd, procPath) 81 require.NoError(t, err) 82 require.NotNil(t, cmd.Process) 83 84 err = TerminatePID(cmd.Process.Pid, time.Millisecond, testLog) 85 assert.NoError(t, err) 86 } 87 88 func assertTerminated(t *testing.T, pid int, stateStr string) { 89 process, err := os.FindProcess(pid) 90 require.NoError(t, err) 91 state, err := process.Wait() 92 require.NoError(t, err) 93 assert.Equal(t, stateStr, state.String()) 94 } 95 96 func TestTerminatePIDInvalid(t *testing.T) { 97 err := TerminatePID(-5, time.Millisecond, testLog) 98 assert.Error(t, err) 99 } 100 101 func TestTerminateAllFn(t *testing.T) { 102 fn := func() ([]ps.Process, error) { 103 return nil, fmt.Errorf("Testing error") 104 } 105 TerminateAllWithProcessesFn(fn, matchAll, time.Millisecond, testLog) 106 107 fn = func() ([]ps.Process, error) { 108 return nil, nil 109 } 110 TerminateAllWithProcessesFn(fn, matchAll, time.Millisecond, testLog) 111 } 112 113 func startProcess(t *testing.T, path string, testCommand string) (string, int, *exec.Cmd) { 114 cmd := exec.Command(path, testCommand) 115 err := cmd.Start() 116 require.NoError(t, err) 117 require.NotNil(t, cmd.Process) 118 return path, cmd.Process.Pid, cmd 119 } 120 121 func TestTerminateAllPathEqual(t *testing.T) { 122 if runtime.GOOS == "windows" { 123 t.Skip("flakey :(") 124 } 125 procPath := procPath(t, "testTerminateAllPathEqual") 126 defer util.RemoveFileAtPath(procPath) 127 matcher := NewMatcher(procPath, PathEqual, testLog) 128 testTerminateAll(t, procPath, matcher, 2) 129 } 130 131 func TestTerminateAllExecutableEqual(t *testing.T) { 132 procPath := procPath(t, "testTerminateAllExecutableEqual") 133 defer util.RemoveFileAtPath(procPath) 134 matcher := NewMatcher(filepath.Base(procPath), ExecutableEqual, testLog) 135 testTerminateAll(t, procPath, matcher, 2) 136 } 137 138 func TestTerminateAllPathContains(t *testing.T) { 139 procPath := procPath(t, "testTerminateAllPathContains") 140 defer util.RemoveFileAtPath(procPath) 141 procDir, procFile := filepath.Split(procPath) 142 match := procDir[1:] + procFile[:20] 143 t.Logf("Match: %q", match) 144 matcher := NewMatcher(match, PathContains, testLog) 145 testTerminateAll(t, procPath, matcher, 2) 146 } 147 148 func TestTerminateAllPathPrefix(t *testing.T) { 149 procPath := procPath(t, "testTerminateAllPathPrefix") 150 defer util.RemoveFileAtPath(procPath) 151 procDir, procFile := filepath.Split(procPath) 152 match := procDir + procFile[:20] 153 t.Logf("Match: %q", match) 154 matcher := NewMatcher(match, PathPrefix, testLog) 155 testTerminateAll(t, procPath, matcher, 2) 156 } 157 158 func testTerminateAll(t *testing.T, path string, matcher Matcher, numProcs int) { 159 var exitStatus string 160 if runtime.GOOS == "windows" { 161 exitStatus = "exit status 1" 162 } else { 163 exitStatus = "signal: terminated" 164 } 165 166 pids := []int{} 167 for i := 0; i < numProcs; i++ { 168 procPath, pid, cmd := startProcess(t, path, "sleep") 169 t.Logf("Started process %q (%d)", procPath, pid) 170 pids = append(pids, pid) 171 defer cleanupProc(cmd, "") 172 } 173 174 time.Sleep(time.Second) 175 176 terminatePids := TerminateAll(matcher, time.Second, testLog) 177 for _, p := range pids { 178 assert.Contains(t, terminatePids, p) 179 assertTerminated(t, p, exitStatus) 180 } 181 } 182 183 func TestFindProcessWait(t *testing.T) { 184 if runtime.GOOS == "windows" { 185 t.Skip("Skipping on windows") 186 } 187 procPath := procPath(t, "testFindProcessWait") 188 cmd := exec.Command(procPath, "sleep") 189 defer cleanupProc(cmd, procPath) 190 191 // Ensure it's not already running 192 procs, err := FindProcesses(NewMatcher(procPath, PathEqual, testLog), time.Millisecond, 0, testLog) 193 require.NoError(t, err) 194 require.Equal(t, 0, len(procs)) 195 196 go func() { 197 time.Sleep(10 * time.Millisecond) 198 startErr := cmd.Start() 199 require.NoError(t, startErr) 200 }() 201 202 // Wait up to second for process to be running 203 procs, err = FindProcesses(NewMatcher(procPath, PathEqual, testLog), time.Second, 10*time.Millisecond, testLog) 204 require.NoError(t, err) 205 require.True(t, len(procs) == 1) 206 }