github.com/nya3jp/tast@v0.0.0-20230601000426-85c8e4d83a9b/src/go.chromium.org/tast/core/internal/runner/runner.go (about) 1 // Copyright 2017 The ChromiumOS Authors 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 package runner 6 7 import ( 8 "context" 9 "io" 10 "io/ioutil" 11 "log" 12 "os" 13 "path/filepath" 14 "sort" 15 "strings" 16 "time" 17 18 "github.com/shirou/gopsutil/v3/process" 19 "golang.org/x/sys/unix" 20 21 "go.chromium.org/tast/core/errors" 22 "go.chromium.org/tast/core/internal/command" 23 "go.chromium.org/tast/core/internal/logging" 24 "go.chromium.org/tast/core/internal/protocol" 25 "go.chromium.org/tast/core/internal/testing" 26 ) 27 28 const ( 29 statusSuccess = 0 // runner was successful 30 _ = 1 // deprecated 31 statusBadArgs = 2 // bad arguments were passed to the runner 32 _ = 3 // deprecated 33 _ = 4 // deprecated 34 _ = 5 // deprecated 35 statusTestFailed = 6 // one or more tests failed during manual run 36 _ = 7 // deprecated 37 _ = 8 // deprecated 38 ) 39 40 // Run reads command-line flags from clArgs and performs the requested action. 41 // clArgs should typically be os.Args[1:]. The caller should exit with the 42 // returned status code. 43 func Run(clArgs []string, stdin io.Reader, stdout, stderr io.Writer, scfg *StaticConfig) int { 44 ctx := context.Background() 45 46 if scfg.EnableSyslog { 47 if l, err := logging.NewSyslogLogger(); err == nil { 48 defer l.Close() 49 ctx = logging.AttachLogger(ctx, l) 50 } 51 } 52 logging.Debug(ctx, "Tast local_runner starts") 53 defer logging.Debug(ctx, "Tast local_runner ends") 54 55 // TODO(b/189332919): Remove this hack and write stack traces to stderr 56 // once we finish migrating to gRPC-based protocol. This hack is needed 57 // because JSON-based protocol is designed to write messages to stderr 58 // in case of errors and thus Tast CLI consumes stderr. 
59 if os.Getenv("TAST_B189332919_STACK_TRACE_FD") == "3" { 60 command.InstallSignalHandler(os.NewFile(3, ""), func(os.Signal) {}) 61 } 62 63 args, err := parseArgs(clArgs, stderr, scfg) 64 if err != nil { 65 return command.WriteError(stderr, err) 66 } 67 68 switch args.Mode { 69 case modeDeprecatedDirectRun: 70 if err := deprecatedDirectRun(ctx, &args.DeprecatedDirectRunConfig, scfg, stdout); err != nil { 71 return command.WriteError(stderr, err) 72 } 73 return statusSuccess 74 case modeRPC: 75 if err := runRPCServer(scfg, stdin, stdout); err != nil { 76 return command.WriteError(stderr, err) 77 } 78 return statusSuccess 79 default: 80 return command.WriteError(stderr, command.NewStatusErrorf(statusBadArgs, "invalid mode %v", args.Mode)) 81 } 82 } 83 84 func deprecatedDirectRun(ctx context.Context, drcfg *DeprecatedDirectRunConfig, scfg *StaticConfig, stdout io.Writer) error { 85 lg := log.New(stdout, "", log.LstdFlags) 86 87 matcher, err := testing.NewMatcher(drcfg.Patterns) 88 if err != nil { 89 return err 90 } 91 92 compat, err := startCompatServer(ctx, scfg, &protocol.HandshakeRequest{ 93 RunnerInitParams: &protocol.RunnerInitParams{ 94 BundleGlob: drcfg.BundleGlob, 95 }, 96 BundleInitParams: &protocol.BundleInitParams{}, 97 }) 98 if err != nil { 99 return err 100 } 101 defer compat.Close() 102 103 cl := compat.Client() 104 105 // Enumerate tests to run. 
106 res, err := cl.ListEntities(ctx, &protocol.ListEntitiesRequest{Features: drcfg.RunConfig(nil).GetFeatures()}) 107 if err != nil { 108 return errors.Wrap(err, "failed to enumerate entities in bundles") 109 } 110 111 var testNames []string 112 for _, r := range res.Entities { 113 e := r.GetEntity() 114 if e.GetType() != protocol.EntityType_TEST { 115 continue 116 } 117 if matcher.Match(e.GetName(), e.GetAttributes()) { 118 testNames = append(testNames, e.GetName()) 119 } 120 } 121 sort.Strings(testNames) 122 123 // We expect to not match any tests if both local and remote tests are being run but the 124 // user specified a pattern that matched only local or only remote tests rather than tests 125 // of both types. Don't bother creating an out dir in that case. 126 if len(testNames) == 0 { 127 return errors.New("no tests matched") 128 } 129 130 rcfg := drcfg.RunConfig(testNames) 131 132 created, err := setUpBaseOutDir(rcfg) 133 if err != nil { 134 return errors.Wrap(err, "failed to set up base out dir") 135 } 136 // If the runner was executed manually and an out dir wasn't specified, clean up the temp dir that was created. 137 if created { 138 defer os.RemoveAll(rcfg.GetDirs().GetOutDir()) 139 } 140 141 // Call RunTests method and send the initial request. 142 srv, err := cl.RunTests(ctx) 143 if err != nil { 144 return errors.Wrap(err, "RunTests: failed to call") 145 } 146 147 initReq := &protocol.RunTestsRequest{Type: &protocol.RunTestsRequest_RunTestsInit{RunTestsInit: &protocol.RunTestsInit{RunConfig: rcfg}}} 148 if err := srv.Send(initReq); err != nil { 149 return errors.Wrap(err, "RunTests: failed to send initial request") 150 } 151 152 numTests := 0 153 testFailed := false // true if error seen for current test 154 var failedTests []string // names of tests with errors 155 var startTime, endTime time.Time // start of first test and end of last test 156 157 // Keep reading responses and convert them to control messages. 
158 for { 159 res, err := srv.Recv() 160 if err == io.EOF { 161 lg.Printf("Ran %d test(s) in %v", numTests, endTime.Sub(startTime).Round(time.Millisecond)) 162 if len(failedTests) > 0 { 163 lg.Printf("%d failed:", len(failedTests)) 164 for _, t := range failedTests { 165 lg.Print(" " + t) 166 } 167 return command.NewStatusErrorf(statusTestFailed, "test(s) failed") 168 } 169 return nil 170 } 171 if err != nil { 172 return err 173 } 174 175 switch res := res.GetType().(type) { 176 case *protocol.RunTestsResponse_RunLog: 177 lg.Print(res.RunLog.GetText()) 178 case *protocol.RunTestsResponse_EntityStart: 179 lg.Print("Running ", res.EntityStart.GetEntity().GetName()) 180 testFailed = false 181 if numTests == 0 { 182 startTime = res.EntityStart.GetTime().AsTime() 183 } 184 case *protocol.RunTestsResponse_EntityLog: 185 lg.Print(res.EntityLog.GetText()) 186 case *protocol.RunTestsResponse_EntityError: 187 e := res.EntityError.GetError() 188 lg.Printf("Error: [%s:%d] %v", filepath.Base(e.GetLocation().GetFile()), e.GetLocation().GetLine(), e.GetReason()) 189 testFailed = true 190 case *protocol.RunTestsResponse_EntityEnd: 191 reasons := res.EntityEnd.GetSkip().GetReasons() 192 if len(reasons) > 0 { 193 lg.Printf("Skipped %s for missing deps: %s", res.EntityEnd.GetEntityName(), strings.Join(reasons, ", ")) 194 } else { 195 lg.Print("Finished ", res.EntityEnd.GetEntityName()) 196 } 197 lg.Print(strings.Repeat("-", 80)) 198 if testFailed { 199 failedTests = append(failedTests, res.EntityEnd.GetEntityName()) 200 } 201 numTests++ 202 endTime = res.EntityEnd.GetTime().AsTime() 203 } 204 } 205 } 206 207 // setUpBaseOutDir creates and assigns a temporary directory if rcfg.Dirs.OutDir is empty. 208 // It also ensures that the dir is accessible to all users. The returned boolean created 209 // indicates whether a new directory was created. 
func setUpBaseOutDir(rcfg *protocol.RunConfig) (created bool, err error) {
	// If any later step fails, remove a directory this call created so the
	// caller is never left with a half-prepared directory it was not told about.
	defer func() {
		if err != nil && created {
			os.RemoveAll(rcfg.GetDirs().GetOutDir())
			created = false
		}
	}()

	outDir := rcfg.GetDirs().GetOutDir()
	if outDir == "" {
		if outDir, err = ioutil.TempDir("", "tast_out."); err != nil {
			return false, err
		}
		rcfg.GetDirs().OutDir = outDir
		created = true
	} else {
		fi, statErr := os.Stat(outDir)
		switch {
		case os.IsNotExist(statErr):
			if err := os.MkdirAll(outDir, 0755); err != nil {
				return false, err
			}
			created = true
		case statErr != nil:
			return false, statErr
		case !fi.IsDir():
			// Refuse to chmod (and later write into) a non-directory the
			// caller pointed us at; failing here gives a clear error instead
			// of an obscure failure later in the run.
			return false, errors.New(outDir + " exists but is not a directory")
		}
	}

	// Make the directory traversable in case a test wants to write a file as another user.
	// (Note that we can't guarantee that all the parent directories are also accessible, though.)
	if err := os.Chmod(outDir, 0755); err != nil {
		return created, err
	}
	return created, nil
}

// killStaleRunners sends sig to the process groups of all other processes whose
// executable is the same as the current process's. Progress and failures are
// logged via logging.Info/Infof on ctx. The function is best-effort and reports
// no error to its caller.
//
// The signal is sent to process group -pid of each matching process, so this
// presumably assumes each stale runner leads its own process group. If it does
// not, the kill fails and the failure is logged.
func killStaleRunners(ctx context.Context, sig unix.Signal) {
	ourPID := os.Getpid()
	ourExe, err := os.Executable()
	if err != nil {
		logging.Info(ctx, "Failed to look up current executable: ", err)
		return
	}

	procs, err := process.Processes()
	if err != nil {
		logging.Info(ctx, "Failed to list processes while looking for stale runners: ", err)
		return
	}
	for _, proc := range procs {
		if int(proc.Pid) == ourPID {
			continue
		}
		// Skip processes running a different binary, as well as those whose
		// executable path cannot be determined.
		if exe, err := proc.Exe(); err != nil || exe != ourExe {
			continue
		}
		logging.Infof(ctx, "Sending signal %d to stale %v process group %d", sig, ourExe, proc.Pid)
		if err := unix.Kill(int(-proc.Pid), sig); err != nil {
			logging.Infof(ctx, "Failed killing process group %d: %v", proc.Pid, err)
		}
	}
}